orbmem 1.0.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- orbmem/__init__.py +3 -0
- orbmem/core/auth.py +26 -0
- orbmem/core/config.py +123 -0
- orbmem/core/ocdb.py +126 -0
- orbmem/db/mongo.py +25 -0
- orbmem/db/neo4j.py +31 -0
- orbmem/db/postgres.py +36 -0
- orbmem/db/redis.py +25 -0
- orbmem/engines/base_engine.py +47 -0
- orbmem/engines/graph/neo4j_backend.py +70 -0
- orbmem/engines/memory/postgres_backend.py +112 -0
- orbmem/engines/memory/redis_backend.py +85 -0
- orbmem/engines/safety/mongo_backend.py +97 -0
- orbmem/engines/safety/timeseries_backend.py +36 -0
- orbmem/engines/vector/qdrant_backend.py +75 -0
- orbmem/models/__init__.py +11 -0
- orbmem/models/fingerprints.py +13 -0
- orbmem/models/memory.py +15 -0
- orbmem/models/safety.py +19 -0
- orbmem/utils/embeddings.py +34 -0
- orbmem/utils/exceptions.py +30 -0
- orbmem/utils/helpers.py +54 -0
- orbmem/utils/logger.py +38 -0
- orbmem/utils/validators.py +53 -0
- orbmem-1.0.4.dist-info/METADATA +177 -0
- orbmem-1.0.4.dist-info/RECORD +29 -0
- orbmem-1.0.4.dist-info/WHEEL +5 -0
- orbmem-1.0.4.dist-info/licenses/LICENSE +23 -0
- orbmem-1.0.4.dist-info/top_level.txt +1 -0
orbmem/__init__.py
ADDED
orbmem/core/auth.py
ADDED
@@ -0,0 +1,26 @@
# core/auth.py

from fastapi import Request
from orbmem.utils.exceptions import AuthError
from orbmem.core.config import load_config


def validate_api_key(request: Request):
    """
    Validates API key using latest .env values.
    Reloads config every time to avoid stale values.
    """
    config = load_config()  # <-- reloads fresh env each request

    auth_header = request.headers.get("Authorization")
    if not auth_header:
        raise AuthError("Missing Authorization header")

    if not auth_header.startswith("Bearer "):
        raise AuthError("Invalid Authorization format. Use: Bearer <API_KEY>")

    api_key = auth_header.replace("Bearer ", "").strip()

    if api_key not in config.api.api_keys:
        raise AuthError("Invalid API key. Access denied.")

    return True
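
validate_api_key raises AuthError instead of returning an HTTP response, so a server embedding it would typically register it as a FastAPI dependency. Below is a minimal sketch of that wiring, assuming no conflicting .env file in the working directory; the app, route, and key values are illustrative and not shipped with the package.

# Hypothetical wiring, not part of orbmem: shows validate_api_key as a FastAPI dependency.
import os

from fastapi import Depends, FastAPI

from orbmem.core.auth import validate_api_key

os.environ.setdefault("OCDB_MODE", "cloud")
os.environ.setdefault("OCDB_API_KEYS", "demo-key-123")                 # illustrative key
os.environ.setdefault("POSTGRES_URL", "postgresql://localhost/ocdb")   # load_config requires it

app = FastAPI()

@app.get("/health", dependencies=[Depends(validate_api_key)])
def health():
    # Reached only when the request carries "Authorization: Bearer demo-key-123";
    # otherwise validate_api_key raises AuthError before the handler runs.
    return {"status": "ok"}
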
orbmem/core/config.py
ADDED
@@ -0,0 +1,123 @@
# core/config.py

import os
from dataclasses import dataclass
from typing import List, Optional

from dotenv import load_dotenv
from orbmem.utils.exceptions import ConfigError


# ===========================
# DATA CLASSES
# ===========================

@dataclass
class DatabaseConfig:
    postgres_url: str
    redis_url: Optional[str] = None
    mongo_url: Optional[str] = None
    neo4j_url: Optional[str] = None


@dataclass
class APIConfig:
    api_keys: List[str]
    debug: bool = False
    mode: str = "local"  # local | cloud


@dataclass
class OCDBConfig:
    db: DatabaseConfig
    api: APIConfig


# ===========================
# HELPERS
# ===========================

def _get_env(name: str, default: Optional[str] = None, required: bool = False) -> Optional[str]:
    """Fetch env variable with optional requirement."""
    value = os.getenv(name, default)

    # treat empty string as None
    if value is not None:
        value = value.strip()
        if value == "":
            value = None

    if required and not value:
        raise ConfigError(f"Missing required environment variable: {name}")

    return value


# ===========================
# MAIN CONFIG LOADER
# ===========================

def load_config() -> OCDBConfig:
    """
    Load configuration fresh every call.
    Respects OCDB_MODE (local/cloud).
    """

    # Reload .env every time (Windows safe)
    load_dotenv(override=True)

    # -------------------------------
    # MODE: local or cloud
    # -------------------------------
    mode = _get_env("OCDB_MODE", default="local").lower()
    if mode not in ("local", "cloud"):
        raise ConfigError("OCDB_MODE must be either 'local' or 'cloud'")

    # -------------------------------
    # DATABASES
    # -------------------------------
    postgres_url = _get_env("POSTGRES_URL", required=True)
    redis_url = _get_env("REDIS_URL")
    mongo_url = _get_env("MONGO_URL")
    neo4j_url = _get_env("NEO4J_URL")  # will be None in local mode

    # -------------------------------
    # API Keys (only required in cloud mode)
    # -------------------------------
    api_keys_raw = _get_env(
        "OCDB_API_KEYS",
        required=(mode == "cloud")  # Only required in cloud mode
    )

    api_keys = []
    if api_keys_raw:
        api_keys = [k.strip() for k in api_keys_raw.split(",") if k.strip()]

    # -------------------------------
    # DEBUG MODE
    # -------------------------------
    debug_raw = _get_env("OCDB_DEBUG", default="0")
    debug = debug_raw.lower() in ("1", "true", "yes", "y")

    # -------------------------------
    # Build config objects
    # -------------------------------
    db_cfg = DatabaseConfig(
        postgres_url=postgres_url,
        redis_url=redis_url,
        mongo_url=mongo_url,
        neo4j_url=neo4j_url,
    )

    api_cfg = APIConfig(
        api_keys=api_keys,
        debug=debug,
        mode=mode,
    )

    print(f"🔧 Loaded OCDB_MODE: {mode}")
    print(f"🔐 Loaded API keys: {api_keys}")
    print(f"🐬 Redis enabled: {bool(redis_url)}")
    print(f"🧠 Neo4j enabled: {bool(neo4j_url)}")

    return OCDBConfig(db=db_cfg, api=api_cfg)
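
Since load_config() rereads the environment (and .env, with override=True) on every call, its behaviour is easiest to pin down with explicit environment variables. A small smoke-test sketch follows, assuming no .env file in the working directory overrides these placeholder values.

# Hypothetical smoke test for the loader; all values are placeholders, not package defaults.
import os

from orbmem.core.config import load_config

os.environ["OCDB_MODE"] = "cloud"
os.environ["OCDB_DEBUG"] = "1"
os.environ["OCDB_API_KEYS"] = "key-one, key-two"     # comma-separated; whitespace is stripped
os.environ["POSTGRES_URL"] = "postgresql://user:pass@localhost:5432/ocdb"
os.environ["REDIS_URL"] = ""                          # empty string is treated as unset

cfg = load_config()
assert cfg.api.mode == "cloud"
assert cfg.api.debug is True
assert cfg.api.api_keys == ["key-one", "key-two"]
assert cfg.db.redis_url is None
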
orbmem/core/ocdb.py
ADDED
@@ -0,0 +1,126 @@
# core/ocdb.py

from typing import Any, Optional, List
from .config import load_config

# MEMORY ENGINE (SQLite/Postgres hybrid)
from orbmem.engines.memory.postgres_backend import PostgresMemoryBackend

# VECTOR ENGINE – Qdrant fallback to FAISS / in-memory
from orbmem.engines.vector.qdrant_backend import QdrantVectorBackend

# GRAPH ENGINE – fallback to NetworkX (in-memory)
from orbmem.engines.graph.neo4j_backend import Neo4jGraphBackend

# SAFETY ENGINE
from orbmem.engines.safety.mongo_backend import MongoSafetyBackend
from orbmem.engines.safety.timeseries_backend import TimeSeriesSafetyBackend


class OCDB:
    """
    High-level interface combining all cognitive engines:
    - Memory Engine (Postgres only, Redis disabled)
    - Vector Engine (FAISS fallback)
    - Graph Engine (NetworkX fallback)
    - Safety Engine (Mongo + TimeSeries)
    """

    def __init__(self):
        cfg = load_config()

        # -------------------------------
        # MEMORY ENGINE (NO REDIS)
        # -------------------------------
        self.pg_memory = PostgresMemoryBackend()
        self.redis_memory = None  # Fully disabled

        # -------------------------------
        # VECTOR ENGINE
        # -------------------------------
        self.vector_engine = QdrantVectorBackend()

        # -------------------------------
        # GRAPH ENGINE
        # -------------------------------
        self.graph_engine = Neo4jGraphBackend()

        # -------------------------------
        # SAFETY ENGINE
        # -------------------------------
        self.safety_event_engine = MongoSafetyBackend()
        self.safety_timeseries = TimeSeriesSafetyBackend()

    # =====================================================
    # MEMORY METHODS — SAFE & CLEAN (NO REDIS)
    # =====================================================

    def memory_set(self, key: str, value: dict, session_id: str = None, ttl_seconds: int = None):
        """
        Store memory in Postgres only (Redis disabled).
        """
        return self.pg_memory.set(key, value, session_id=session_id, ttl_seconds=ttl_seconds)

    def memory_get(self, key: str):
        """
        Retrieve memory from Postgres only.
        """
        return self.pg_memory.get(key)

    def memory_keys(self) -> List[str]:
        """
        List all memory keys.
        """
        return self.pg_memory.keys()

    # =====================================================
    # VECTOR METHODS
    # =====================================================

    def vector_search(self, query: str, k: int = 5):
        """
        Search vector embeddings using FAISS fallback (no Qdrant needed).
        """
        return self.vector_engine.search(query, k=k)

    # =====================================================
    # GRAPH METHODS
    # =====================================================

    def graph_add(self, node_id: str, content: str, parent: Optional[str] = None):
        """
        Add a node to in-memory reasoning graph.
        """
        return self.graph_engine.add_node(node_id, content, parent)

    def graph_path(self, start: str, end: str):
        """
        Get path between two graph nodes.
        """
        return self.graph_engine.get_path(start, end)

    # =====================================================
    # SAFETY METHODS
    # =====================================================

    def safety_scan(self, text: str):
        """
        Detect safety events and record in timeseries engine.
        """
        events = self.safety_event_engine.scan(text)

        # update timeseries fingerprint
        for evt in events:
            self.safety_timeseries.add_point(evt.tag, evt.severity)

        return [
            {
                "text": evt.text,
                "tag": evt.tag,
                "severity": evt.severity,
                "correction": evt.correction,
                "details": evt.details,
                "timestamp": evt.timestamp,
            }
            for evt in events
        ]
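
OCDB is a thin facade: each public method delegates to exactly one backend. Assuming POSTGRES_URL is set and the fallback backends construct without their external services (as their docstrings suggest for local mode), a session could look like the sketch below; keys, node ids, and payloads are illustrative.

# Hypothetical use of the OCDB facade; not part of the package.
from orbmem.core.ocdb import OCDB

db = OCDB()

# Memory: delegated to the SQLite-backed PostgresMemoryBackend
db.memory_set("user:42:profile", {"name": "Ada"}, session_id="s1", ttl_seconds=3600)
profile = db.memory_get("user:42:profile")   # -> {"name": "Ada"} until the TTL lapses
print(db.memory_keys())

# Graph: delegated to the NetworkX-backed graph engine
db.graph_add("step-1", "parse the question")
db.graph_add("step-2", "retrieve context", parent="step-1")
print(db.graph_path("step-1", "step-2"))     # -> ["step-1", "step-2"]
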
orbmem/db/mongo.py
ADDED
@@ -0,0 +1,25 @@
# db/mongo.py

from pymongo import MongoClient
from orbmem.core.config import load_config
from orbmem.utils.logger import get_logger
from orbmem.utils.exceptions import DatabaseError

logger = get_logger(__name__)

CONFIG = load_config()
MONGO_URL = CONFIG.db.mongo_url


def get_mongo_client():
    """Get MongoDB client for safety engine."""
    if not MONGO_URL:
        return None

    try:
        client = MongoClient(MONGO_URL)
        logger.info("MongoDB connection established.")
        return client
    except Exception as e:
        logger.error(f"MongoDB connection failed: {e}")
        raise DatabaseError(f"MongoDB init error: {e}")
orbmem/db/neo4j.py
ADDED
@@ -0,0 +1,31 @@
# db/neo4j.py

from neo4j import GraphDatabase
from orbmem.utils.logger import get_logger
from orbmem.utils.exceptions import DatabaseError
from orbmem.core.config import load_config

logger = get_logger(__name__)

_driver = None


def get_neo4j_driver():
    global _driver

    if _driver:
        return _driver

    cfg = load_config()
    url = cfg.db.neo4j_url
    user = cfg.db.neo4j_user
    password = cfg.db.neo4j_password

    try:
        _driver = GraphDatabase.driver(url, auth=(user, password))
        logger.info("Neo4j driver initialized.")
        return _driver

    except Exception as e:
        logger.error(f"Neo4j initialization failed: {e}")
        raise DatabaseError(f"Neo4j init error: {e}")
orbmem/db/postgres.py
ADDED
@@ -0,0 +1,36 @@
# db/postgres.py

from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker, declarative_base
from orbmem.core.config import load_config
from orbmem.utils.logger import get_logger
from orbmem.utils.exceptions import DatabaseError

logger = get_logger(__name__)

# Load configuration
CONFIG = load_config()
POSTGRES_URL = CONFIG.db.postgres_url

# Create SQLAlchemy engine
try:
    engine = create_engine(
        POSTGRES_URL,
        echo=False,
        pool_pre_ping=True,  # auto-detect dead connections
        pool_recycle=1800    # refresh stale connections
    )
    logger.info("PostgreSQL engine initialized successfully.")
except Exception as e:
    logger.error(f"Failed to initialize PostgreSQL engine: {e}")
    raise DatabaseError(f"PostgreSQL init error: {e}")

# Create database session factory
SessionLocal = sessionmaker(
    autocommit=False,
    autoflush=False,
    bind=engine
)

# Base class for all SQLAlchemy ORM models
Base = declarative_base()
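
SessionLocal and Base follow the standard SQLAlchemy declarative pattern, so downstream modules would define models against Base and open short-lived sessions from SessionLocal. A sketch of such a consumer follows, assuming POSTGRES_URL points at a reachable database; the Note model is hypothetical and not part of orbmem.

# Hypothetical consumer of db/postgres.py; the Note model is illustrative only.
from sqlalchemy import Column, Integer, String

from orbmem.db.postgres import Base, SessionLocal, engine

class Note(Base):
    __tablename__ = "notes"
    id = Column(Integer, primary_key=True)
    body = Column(String, nullable=False)

# Create tables for every model registered on Base
Base.metadata.create_all(bind=engine)

session = SessionLocal()
try:
    session.add(Note(body="hello"))
    session.commit()
finally:
    session.close()
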
orbmem/db/redis.py
ADDED
@@ -0,0 +1,25 @@
# db/redis.py

import redis
from orbmem.core.config import load_config
from orbmem.utils.logger import get_logger
from orbmem.utils.exceptions import DatabaseError

logger = get_logger(__name__)

CONFIG = load_config()
REDIS_URL = CONFIG.db.redis_url


def get_redis_client():
    """Create and return a Redis client instance."""
    if not REDIS_URL:
        return None

    try:
        client = redis.Redis.from_url(REDIS_URL, decode_responses=True)
        logger.info("Redis connection established.")
        return client
    except Exception as e:
        logger.error(f"Redis connection failed: {e}")
        raise DatabaseError(f"Redis init error: {e}")
orbmem/engines/base_engine.py
ADDED
@@ -0,0 +1,47 @@
# engines/base_engine.py

from abc import ABC, abstractmethod
from typing import Any, Optional


class BaseEngine(ABC):
    """
    Abstract base class for all OCDB engines.
    Defines optional generic methods common across engines.
    """

    # -------------------------------
    # MEMORY-LIKE INTERFACES
    # -------------------------------
    def set(self, key: str, value: Any, **kwargs):
        raise NotImplementedError("set() not implemented for this engine")

    def get(self, key: str) -> Any:
        raise NotImplementedError("get() not implemented for this engine")

    def delete(self, key: str):
        raise NotImplementedError("delete() not implemented for this engine")

    # -------------------------------
    # GRAPH-LIKE INTERFACES
    # -------------------------------
    def add_node(self, *args, **kwargs):
        raise NotImplementedError("add_node() not implemented")

    def get_path(self, *args, **kwargs):
        raise NotImplementedError("get_path() not implemented")

    # -------------------------------
    # VECTOR-LIKE INTERFACES
    # -------------------------------
    def add_text(self, *args, **kwargs):
        raise NotImplementedError("add_text() not implemented")

    def search(self, *args, **kwargs):
        raise NotImplementedError("search() not implemented")

    # -------------------------------
    # SAFETY-LIKE INTERFACES
    # -------------------------------
    def scan(self, *args, **kwargs):
        raise NotImplementedError("scan() not implemented")
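
Because every interface on BaseEngine defaults to raising NotImplementedError (and nothing is marked @abstractmethod), a concrete engine only overrides the family of methods it actually supports. A minimal hypothetical subclass:

# Hypothetical engine built on BaseEngine; not part of the package.
from typing import Any, Dict

from orbmem.engines.base_engine import BaseEngine

class DictMemoryEngine(BaseEngine):
    """Toy memory-like engine backed by a plain dict."""

    def __init__(self):
        self._store: Dict[str, Any] = {}

    def set(self, key: str, value: Any, **kwargs):
        self._store[key] = value

    def get(self, key: str) -> Any:
        return self._store.get(key)

    def delete(self, key: str):
        self._store.pop(key, None)

# The unimplemented families still raise,
# e.g. DictMemoryEngine().scan("x") -> NotImplementedError
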
orbmem/engines/graph/neo4j_backend.py
ADDED
@@ -0,0 +1,70 @@
# engines/graph/neo4j_backend.py
# In-memory Graph Engine using NetworkX (Neo4j replacement)

import networkx as nx
from typing import Dict, Any, List
from orbmem.utils.logger import get_logger
from orbmem.utils.exceptions import DatabaseError

logger = get_logger(__name__)


class Neo4jGraphBackend:
    """
    Lightweight in-memory graph engine using NetworkX.
    Fully compatible with OCDB.graph_add() and OCDB.graph_path().
    """

    def __init__(self):
        try:
            self.graph = nx.DiGraph()
            logger.info("In-memory Graph Engine initialized (NetworkX).")
        except Exception as e:
            raise DatabaseError(f"Failed to initialize graph engine: {e}")

    # ---------------------------------------------------------
    # MATCHES ocdb.graph_add(node_id, content, parent)
    # ---------------------------------------------------------
    def add_node(self, node_id: str, content: str, parent: str = None):
        try:
            # Add node with content property
            self.graph.add_node(node_id, content=content)
            logger.info(f"Graph node added: {node_id}")

            # Add edge to parent if provided
            if parent:
                self.graph.add_edge(parent, node_id, relation="next")
                logger.info(f"Graph edge added: {parent} -> {node_id}")

            return {"node_id": node_id, "parent": parent}

        except Exception as e:
            raise DatabaseError(f"Failed adding node: {e}")

    # ---------------------------------------------------------
    # MATCHES ocdb.graph_path(start, end)
    # ---------------------------------------------------------
    def get_path(self, start: str, end: str) -> List[str]:
        try:
            return nx.shortest_path(self.graph, source=start, target=end)
        except Exception:
            return []

    # ---------------------------------------------------------
    # Debug export (optional)
    # ---------------------------------------------------------
    def export(self):
        return {
            "nodes": [
                {"id": n, "properties": self.graph.nodes[n]}
                for n in self.graph.nodes
            ],
            "edges": [
                {
                    "from": u,
                    "to": v,
                    "relation": self.graph.edges[u, v].get("relation", "next")
                }
                for u, v in self.graph.edges
            ]
        }
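
The backend holds a single nx.DiGraph in process memory, so paths exist only in the direction edges were added and nothing survives a restart. A short usage sketch with illustrative node ids:

# Hypothetical use of the NetworkX-backed graph engine; ids and content are illustrative.
from orbmem.engines.graph.neo4j_backend import Neo4jGraphBackend

graph = Neo4jGraphBackend()
graph.add_node("q1", "user question")
graph.add_node("r1", "retrieved fact", parent="q1")
graph.add_node("a1", "final answer", parent="r1")

print(graph.get_path("q1", "a1"))   # -> ["q1", "r1", "a1"]
print(graph.get_path("a1", "q1"))   # -> [] (DiGraph: the reverse path does not exist)
print(graph.export()["edges"])      # relation defaults to "next"
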
orbmem/engines/memory/postgres_backend.py
ADDED
@@ -0,0 +1,112 @@
# engines/memory/postgres_backend.py
# Repurposed as SQLite backend (lightweight local mode)

import json
import sqlite3
from datetime import datetime, timedelta
from orbmem.utils.logger import get_logger
from orbmem.utils.exceptions import DatabaseError

logger = get_logger(__name__)

DB_PATH = "ocdb.sqlite3"


class PostgresMemoryBackend:
    """
    Lightweight SQLite-based memory engine.
    Replaces PostgreSQL for systems with low storage.
    """

    def __init__(self):
        try:
            self.conn = sqlite3.connect(DB_PATH, check_same_thread=False)
            self.cursor = self.conn.cursor()
            self._init_tables()
            logger.info("SQLite Memory Backend initialized (replacing PostgreSQL).")
        except Exception as e:
            raise DatabaseError(f"SQLite init error: {e}")

    # ---------------------------------------------------------
    # Create tables if missing
    # ---------------------------------------------------------
    def _init_tables(self):
        self.cursor.execute("""
            CREATE TABLE IF NOT EXISTS memory (
                key TEXT PRIMARY KEY,
                value TEXT,
                session_id TEXT,
                expires_at TEXT
            )
        """)
        self.conn.commit()

    # ---------------------------------------------------------
    # Set memory key
    # ---------------------------------------------------------
    def set(self, key: str, value, session_id: str = None, ttl_seconds: int = None):
        try:
            expires_at = None
            if ttl_seconds:
                expires_at = (datetime.utcnow() + timedelta(seconds=ttl_seconds)).isoformat()

            value_json = json.dumps(value)

            self.cursor.execute("""
                INSERT INTO memory (key, value, session_id, expires_at)
                VALUES (?, ?, ?, ?)
                ON CONFLICT(key) DO UPDATE SET
                    value=excluded.value,
                    session_id=excluded.session_id,
                    expires_at=excluded.expires_at
            """, (key, value_json, session_id, expires_at))

            self.conn.commit()

        except Exception as e:
            raise DatabaseError(f"SQLite write error: {e}")

    # ---------------------------------------------------------
    # Get memory key
    # ---------------------------------------------------------
    def get(self, key: str):
        try:
            self.cursor.execute("SELECT value, expires_at FROM memory WHERE key=?", (key,))
            row = self.cursor.fetchone()

            if not row:
                return None

            value_json, expires_at = row

            # TTL check
            if expires_at:
                if datetime.utcnow() > datetime.fromisoformat(expires_at):
                    self.delete(key)
                    return None

            return json.loads(value_json)

        except Exception as e:
            raise DatabaseError(f"SQLite read error: {e}")

    # ---------------------------------------------------------
    # Delete a memory key
    # ---------------------------------------------------------
    def delete(self, key: str):
        try:
            self.cursor.execute("DELETE FROM memory WHERE key=?", (key,))
            self.conn.commit()
        except Exception as e:
            raise DatabaseError(f"SQLite delete error: {e}")

    # ---------------------------------------------------------
    # List memory keys
    # ---------------------------------------------------------
    def keys(self):
        try:
            self.cursor.execute("SELECT key FROM memory")
            rows = self.cursor.fetchall()
            return [r[0] for r in rows]
        except Exception as e:
            raise DatabaseError(f"SQLite keys error: {e}")
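
Despite the class name, persistence here is a local ocdb.sqlite3 file, with TTL enforced lazily when a key is read. A round-trip sketch with an illustrative key and payload:

# Hypothetical round trip through the SQLite-backed memory engine.
import time

from orbmem.engines.memory.postgres_backend import PostgresMemoryBackend

mem = PostgresMemoryBackend()        # creates/opens ocdb.sqlite3 in the working directory
mem.set("greeting", {"text": "hello"}, session_id="s1", ttl_seconds=1)

print(mem.get("greeting"))           # -> {"text": "hello"}
print(mem.keys())                    # -> ["greeting", ...]

time.sleep(2)
print(mem.get("greeting"))           # -> None; the expired row is deleted on read
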