orbmem 1.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,85 @@
1
+ # engines/memory/redis_backend.py
2
+
3
+ import json
4
+ from typing import Any, Dict, Optional, List
5
+ from orbmem.db.redis import get_redis_client
6
+ from orbmem.utils.logger import get_logger
7
+
8
+ logger = get_logger(__name__)
9
+
10
+
11
class RedisMemoryBackend:
    """
    High-speed TTL memory using Redis.
    Best for short-term memory, active sessions, caching.

    Values are stored as JSON text (``json.dumps`` on write, ``json.loads``
    on read). If ``get_redis_client()`` returns a falsy value the backend is
    disabled: writes and deletes do nothing and reads return empty results.
    """

    # Characters that are escaped before a session id is placed in a key
    # pattern, so the id is always matched literally.
    _GLOB_SPECIALS = "*?[]\\"

    def __init__(self):
        self.client = get_redis_client()
        if self.client:
            logger.info("RedisMemoryBackend initialized.")
        else:
            logger.warning("REDIS_URL not configured. RedisBackend disabled.")

    # ---------------------------------------------------------
    # Internal helpers
    # ---------------------------------------------------------
    @staticmethod
    def _as_text(key: Any) -> str:
        """Return *key* as ``str`` whether the client yields bytes or str."""
        if isinstance(key, (bytes, bytearray)):
            return key.decode()
        return str(key)

    @classmethod
    def _escape_glob(cls, text: str) -> str:
        """Backslash-escape glob metacharacters so *text* matches literally."""
        return "".join(
            "\\" + ch if ch in cls._GLOB_SPECIALS else ch for ch in text
        )

    # ---------------------------------------------------------
    # Core Operations
    # ---------------------------------------------------------
    def set(self, key: str, value: Any, ttl_seconds: Optional[int] = None):
        """
        Store ``value`` (JSON-encoded) under ``key``.

        A falsy ``ttl_seconds`` (``None`` or ``0``) stores the key without
        an expiry; otherwise the key is written with ``setex``.
        """
        if not self.client:
            return

        value_json = json.dumps(value)
        if ttl_seconds:
            self.client.setex(key, ttl_seconds, value_json)
        else:
            self.client.set(key, value_json)

        logger.info(f"[Redis] Set key '{key}'")

    def get(self, key: str) -> Optional[Any]:
        """Return the JSON-decoded value for ``key``, or ``None`` if absent."""
        if not self.client:
            return None

        val = self.client.get(key)
        if val is None:
            return None

        return json.loads(val)

    def delete(self, key: str):
        """Delete ``key``; does nothing when the backend is disabled."""
        if self.client:
            self.client.delete(key)
            logger.info(f"[Redis] Deleted key '{key}'")

    def keys(self) -> List[str]:
        """Return every key name as ``str`` (bytes keys are decoded)."""
        if not self.client:
            return []
        return [self._as_text(k) for k in self.client.keys("*")]

    # ---------------------------------------------------------
    # Session support via prefixing (session:<id>:key)
    # ---------------------------------------------------------
    def set_session(self, session_id: str, key: str, value: Any, ttl: Optional[int] = None):
        """Store ``value`` under ``session:<session_id>:<key>``."""
        full_key = f"session:{session_id}:{key}"
        self.set(full_key, value, ttl)

    def get_session(self, session_id: str) -> Dict[str, Any]:
        """
        Return ``{key: value}`` for every entry stored for ``session_id``.

        The ``session:<session_id>:`` prefix is stripped by length rather
        than by splitting on ``:``, so session ids that themselves contain
        ``:`` are handled correctly.
        """
        if not self.client:
            return {}

        prefix = f"session:{session_id}:"
        pattern = self._escape_glob(prefix) + "*"
        results = {}
        for raw_key in self.client.keys(pattern):
            name = self._as_text(raw_key)
            if not name.startswith(prefix):
                # Defensive: ignore anything outside this session's prefix.
                continue
            results[name[len(prefix):]] = self.get(raw_key)
        return results

    def delete_session(self, session_id: str):
        """Delete every key stored for ``session_id``."""
        if not self.client:
            return

        # Escape the id so a value such as "*" cannot match other sessions.
        pattern = self._escape_glob(f"session:{session_id}:") + "*"
        for key in self.client.keys(pattern):
            self.client.delete(key)

        logger.info(f"[Redis] Deleted session '{session_id}'")
@@ -0,0 +1,97 @@
1
+ # engines/safety/mongo_backend.py
2
+
3
+ import re
4
+ import time
5
+ from typing import List, Dict, Any, Optional
6
+
7
+ from orbmem.db.mongo import get_mongo_client
8
+ from orbmem.utils.logger import get_logger
9
+
10
+ logger = get_logger(__name__)
11
+
12
+
13
class SafetyEvent:
    """
    One detected safety violation, timestamped when the object is built.
    """

    # Attribute names exported by ``to_dict``, in output order.
    _FIELDS = ("text", "tag", "severity", "correction", "details", "timestamp")

    def __init__(
        self,
        text: str,
        tag: str,
        severity: float,
        correction: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ):
        self.text, self.tag, self.severity = text, tag, severity
        # Falsy corrections (e.g. "") collapse to None.
        self.correction = correction if correction else None
        # The ``metadata`` argument is exposed as ``details``.
        self.details = metadata if metadata else {}
        self.timestamp = time.time()

    def to_dict(self):
        """Return the event's fields as a plain dictionary."""
        return {name: getattr(self, name) for name in self._FIELDS}
42
+
43
+
44
+ class MongoSafetyBackend:
45
+ """
46
+ Scans text for unsafe content and logs events to MongoDB.
47
+ """
48
+
49
+ DEFAULT_PATTERNS = {
50
+ "self_harm": re.compile(r"(suicide|kill myself|hurt myself)", re.IGNORECASE),
51
+ "violence": re.compile(r"(kill|shoot|stab|attack)", re.IGNORECASE),
52
+ "hate": re.compile(r"(racial slur|hate\s+speech|bigot)", re.IGNORECASE),
53
+ "privacy": re.compile(r"(password|otp|aadhaar|credit card)", re.IGNORECASE),
54
+ }
55
+
56
+ def __init__(self):
57
+ self.client = get_mongo_client()
58
+ if self.client:
59
+ self.collection = self.client["ocdb"]["safety_events"]
60
+ logger.info("MongoSafetyBackend initialized.")
61
+ else:
62
+ self.collection = None
63
+ logger.warning("MongoSafetyBackend disabled (no MongoDB).")
64
+
65
+ # ---------------------------------------------------------
66
+ # Scoring model
67
+ # ---------------------------------------------------------
68
+ def _severity(self, tag: str, text: str) -> float:
69
+ base = {
70
+ "self_harm": 0.9,
71
+ "violence": 0.7,
72
+ "hate": 0.8,
73
+ "privacy": 0.6,
74
+ }.get(tag, 0.5)
75
+
76
+ length_factor = min(len(text) / 200, 1.0)
77
+ return round(base * length_factor, 3)
78
+
79
+ # ---------------------------------------------------------
80
+ # Main scan function
81
+ # ---------------------------------------------------------
82
+ def scan(self, text: str, metadata: Optional[Dict[str, Any]] = None) -> List[SafetyEvent]:
83
+ if not text:
84
+ return []
85
+
86
+ events = []
87
+
88
+ for tag, pattern in self.DEFAULT_PATTERNS.items():
89
+ if pattern.search(text):
90
+ severity = self._severity(tag, text)
91
+ evt = SafetyEvent(text, tag, severity, metadata=metadata)
92
+ events.append(evt)
93
+
94
+ if self.collection:
95
+ self.collection.insert_one(evt.to_dict())
96
+
97
+ return events
@@ -0,0 +1,36 @@
1
+ # engines/safety/timeseries_backend.py
2
+
3
+ import time
4
+ from typing import Dict, List
5
+
6
+ from orbmem.utils.logger import get_logger
7
+
8
+ logger = get_logger(__name__)
9
+
10
+
11
class TimeSeriesSafetyBackend:
    """
    In-memory time series of safety scores, grouped by tag.
    In production, this could be replaced with TimescaleDB, InfluxDB, etc.
    """

    def __init__(self):
        # Maps tag -> list of {"timestamp": ..., "score": ...} points.
        self.store: Dict[str, List[Dict]] = {}
        logger.info("TimeSeriesSafetyBackend initialized.")

    def add_point(self, tag: str, severity: float):
        """
        Append a point for ``tag`` stamped with the current time.
        """
        point = {"timestamp": time.time(), "score": severity}
        self.store.setdefault(tag, []).append(point)

        logger.info(f"Added safety point: tag={tag}, severity={severity}")

    def get_series(self, tag: str):
        """Return the stored points for ``tag``, or ``[]`` if none exist."""
        if tag in self.store:
            return self.store[tag]
        return []
@@ -0,0 +1,75 @@
1
+ # engines/vector/qdrant_backend.py
2
+ # FAISS fallback vector engine (no Qdrant required)
3
+
4
+ import faiss
5
+ import numpy as np
6
+ from typing import List, Dict, Any
7
+ from ...utils.logger import get_logger
8
+ from ...utils.embeddings import embed_text
9
+ from orbmem.utils.exceptions import DatabaseError
10
+
11
+ logger = get_logger(__name__)
12
+
13
+
14
class QdrantVectorBackend:
    """
    Lightweight FAISS-based vector engine.
    Behaves like Qdrant but stores vectors in-memory.
    Fully compatible with OCDB vector API.

    Payloads are kept in a list parallel to the FAISS index: the payload at
    position ``i`` belongs to the ``i``-th vector added.
    """

    def __init__(self, dim: int = 384):
        self.dim = dim

        # FAISS L2 index for similarity search
        self.index = faiss.IndexFlatL2(dim)

        # Store payloads manually
        self.payloads: List[Dict[str, Any]] = []

        logger.info("FAISS vector engine initialized (Qdrant replacement).")

    # ---------------------------------------------------------
    # Insert text into vector DB
    # ---------------------------------------------------------
    def add_text(self, text: str, payload: Dict[str, Any]):
        """
        Embed ``text`` and store it together with ``payload``.

        Raises:
            DatabaseError: if embedding or indexing fails; the original
                exception is chained as ``__cause__``.
        """
        try:
            vector = np.array([embed_text(text)], dtype="float32")

            self.index.add(vector)
            self.payloads.append(payload)

            logger.info(f"Vector added for ID={payload.get('id')}")

        except Exception as e:
            logger.error(f"FAISS add_text error: {e}")
            raise DatabaseError(str(e)) from e

    # ---------------------------------------------------------
    # Search for similar vectors
    # ---------------------------------------------------------
    def search(self, query: str, k: int = 5) -> List[Dict[str, Any]]:
        """
        Return up to ``k`` stored payloads nearest to ``query``.

        Each result is ``{"score": <distance from the index>, "payload": ...}``.
        An empty index yields ``[]`` without embedding the query.

        Raises:
            DatabaseError: if embedding or searching fails; the original
                exception is chained as ``__cause__``.
        """
        try:
            # Check emptiness first so an empty store never pays for an
            # embedding call.
            if self.index.ntotal == 0:
                return []

            vector = np.array([embed_text(query)], dtype="float32")
            distances, indices = self.index.search(vector, k)

            # An index of -1 marks an unfilled result slot; skip those.
            return [
                {"score": float(dist), "payload": self.payloads[int(idx)]}
                for dist, idx in zip(distances[0], indices[0])
                if idx != -1
            ]

        except Exception as e:
            logger.error(f"FAISS search error: {e}")
            raise DatabaseError(str(e)) from e
@@ -0,0 +1,11 @@
1
+ # models/__init__.py
2
+
3
+ from .memory import MemoryRecord
4
+ from .safety import SafetyEvent
5
+ from .fingerprints import SafetyFingerprint
6
+
7
+ __all__ = [
8
+ "MemoryRecord",
9
+ "SafetyEvent",
10
+ "SafetyFingerprint",
11
+ ]
@@ -0,0 +1,13 @@
1
+ # models/fingerprints.py
2
+
3
+ from sqlalchemy import Column, Integer, String, Float, TIMESTAMP
4
+ from sqlalchemy.sql import func
5
+ from orbmem.db.postgres import Base
6
+
7
class SafetyFingerprint(Base):
    """Mapped class for rows of the ``safety_fingerprints`` table."""

    __tablename__ = "safety_fingerprints"

    # Integer primary key.
    id = Column(Integer, primary_key=True)

    # Required, indexed tag string.
    tag = Column(String, nullable=False, index=True)

    # Required float score.
    score = Column(Float, nullable=False)

    # Timezone-aware timestamp; the server default is ``func.now()``.
    timestamp = Column(
        TIMESTAMP(timezone=True),
        server_default=func.now(),
    )
@@ -0,0 +1,15 @@
1
+ # models/memory.py
2
+
3
+ from sqlalchemy import Column, Integer, String, JSON, TIMESTAMP
4
+ from sqlalchemy.sql import func
5
+ from orbmem.db.postgres import Base
6
+
7
class MemoryRecord(Base):
    """Mapped class for rows of the ``memory_records`` table."""

    __tablename__ = "memory_records"

    # Integer primary key (also flagged ``index=True``).
    id = Column(Integer, index=True, primary_key=True)

    # Required, unique, indexed key string.
    key = Column(String, nullable=False, unique=True, index=True)

    # Optional, indexed session identifier.
    session_id = Column(String, nullable=True, index=True)

    # Required JSON value.
    value = Column(JSON, nullable=False)

    # Timezone-aware creation time; the server default is ``func.now()``.
    created_at = Column(
        TIMESTAMP(timezone=True),
        server_default=func.now(),
    )

    # Optional timezone-aware expiry time.
    expires_at = Column(TIMESTAMP(timezone=True), nullable=True)
@@ -0,0 +1,19 @@
1
+ # models/safety.py
2
+
3
+ from sqlalchemy import Column, Integer, String, Float, JSON, TIMESTAMP
4
+ from sqlalchemy.sql import func
5
+ from orbmem.db.postgres import Base
6
+
7
class SafetyEvent(Base):
    """Mapped class for rows of the ``safety_events`` table."""

    __tablename__ = "safety_events"

    # Integer primary key.
    id = Column(Integer, primary_key=True)

    # Required text column.
    text = Column(String, nullable=False)

    # Required, indexed tag string.
    tag = Column(String, nullable=False, index=True)

    # Required float severity.
    severity = Column(Float, nullable=False)

    # Optional correction string.
    correction = Column(String, nullable=True)

    # Optional JSON column; named ``details`` (formerly ``metadata``).
    details = Column(JSON, nullable=True)

    # Timezone-aware timestamp; the server default is ``func.now()``.
    timestamp = Column(
        TIMESTAMP(timezone=True),
        server_default=func.now(),
    )
@@ -0,0 +1,34 @@
1
+ # utils/embeddings.py
2
+
3
+ from functools import lru_cache
4
+ from sentence_transformers import SentenceTransformer
5
+ import numpy as np
6
+ from orbmem.utils.logger import get_logger
7
+
8
+ logger = get_logger(__name__)
9
+
10
+
11
@lru_cache(maxsize=1)
def get_embedding_model():
    """
    Build the sentence-embedding model; ``lru_cache`` keeps the result so
    repeated calls reuse the same instance.
    MiniLM-L6 model → 384-dimensional vectors.
    """
    checkpoint = "sentence-transformers/all-MiniLM-L6-v2"
    logger.info(f"Loading embedding model: {checkpoint}")
    encoder = SentenceTransformer(checkpoint)
    return encoder
20
+
21
+
22
def embed_text(text: str):
    """
    Embed ``text`` as a list of 384 floats (JSON serializable).

    Falsy or non-string input yields a vector of 384 zeros without
    touching the model.
    """
    if not (text and isinstance(text, str)):
        return [0.0 for _ in range(384)]

    embedding = get_embedding_model().encode(text)

    # NumPy array → plain list of Python floats
    as_floats = embedding.astype(float)
    return as_floats.tolist()
@@ -0,0 +1,30 @@
1
+ # utils/exceptions.py
2
+
3
class OCDBError(Exception):
    """Common base class for every OCDB exception."""
6
+
7
+
8
class ConfigError(OCDBError):
    """Configuration is missing or invalid."""
11
+
12
+
13
class DatabaseError(OCDBError):
    """A database connection or query failed."""
16
+
17
+
18
class AuthError(OCDBError):
    """Authentication or API key validation failed."""
21
+
22
+
23
class NotFoundError(OCDBError):
    """A requested resource does not exist."""
26
+
27
+
28
class ValidationError(OCDBError):
    """User input or a payload is invalid."""
@@ -0,0 +1,54 @@
1
+ # utils/helpers.py
2
+
3
+ import json
4
+ import time
5
+ from typing import Any, Dict
6
+
7
+
8
def now_ts() -> float:
    """
    Return ``time.time()``: the current Unix timestamp as a float.
    """
    current = time.time()
    return current
13
+
14
+
15
def safe_json(data: Any) -> str:
    """
    Serialize ``data`` to a JSON string without ever raising.

    Non-ASCII characters are kept as-is, values JSON cannot encode are
    passed through ``str``, and any failure yields ``"{}"``.
    """
    try:
        encoded = json.dumps(data, default=str, ensure_ascii=False)
    except Exception:
        encoded = "{}"
    return encoded
24
+
25
+
26
def deep_clean_dict(data: Dict[str, Any]) -> Dict[str, Any]:
    """
    Remove None values from dict recursively.
    Useful for preparing clean API responses.

    Dict entries whose value is ``None`` are dropped at every depth,
    including inside dicts nested in lists. List elements themselves are
    never removed, so positions are preserved. Any other input is returned
    unchanged. The input is not mutated; new dicts and lists are built.
    """
    if isinstance(data, list):
        # Descend into lists so dicts nested inside them are cleaned too.
        return [deep_clean_dict(item) for item in data]

    if not isinstance(data, dict):
        return data

    return {
        key: deep_clean_dict(value)
        for key, value in data.items()
        if value is not None
    }
44
+
45
+
46
def ensure_str(value: Any) -> str:
    """
    Return ``str(value)``, or ``"<unprintable>"`` if that conversion raises.
    Keeps logging and database code from crashing on odd objects.
    """
    try:
        text = str(value)
    except Exception:
        text = "<unprintable>"
    return text
orbmem/utils/logger.py ADDED
@@ -0,0 +1,38 @@
1
+ # utils/logger.py
2
+
3
+ import logging
4
+ import sys
5
+ from typing import Optional
6
+
7
+
8
def _create_handler() -> logging.Handler:
    """Build a stdout stream handler with the package's log line format."""
    console = logging.StreamHandler(sys.stdout)
    console.setFormatter(
        logging.Formatter(
            fmt="[%(asctime)s] [%(levelname)s] [%(name)s] %(message)s",
            datefmt="%Y-%m-%d %H:%M:%S",
        )
    )
    return console
17
+
18
+
19
def get_logger(name: Optional[str] = None) -> logging.Logger:
    """
    Return the logger called ``name`` (``"ocdb"`` when ``name`` is falsy).

    A logger that has no handlers yet gets a console handler, level INFO,
    and propagation switched off. A logger that already has handlers is
    returned untouched, so repeated calls never stack handlers.

    Usage:
        from orbmem.utils.logger import get_logger
        logger = get_logger(__name__)
        logger.info("Hello from OCDB")
    """
    log = logging.getLogger(name if name else "ocdb")

    # Already configured: hand it back as-is.
    if log.handlers:
        return log

    log.addHandler(_create_handler())
    log.setLevel(logging.INFO)
    log.propagate = False
    return log
@@ -0,0 +1,53 @@
1
+ # utils/validators.py
2
+
3
+ import re
4
+ from typing import Any, Dict
5
+
6
+
7
class ValidationError(Exception):
    """Error raised by the validators in this module."""
10
+
11
+
12
def validate_non_empty(value: Any, field_name: str = "value"):
    """
    Return ``value`` unless it is ``None`` or a blank/whitespace-only string,
    in which case ``ValidationError`` is raised naming ``field_name``.
    """
    blank_text = isinstance(value, str) and value.strip() == ""
    if value is None or blank_text:
        raise ValidationError(f"{field_name} cannot be empty.")
    return value
17
+
18
+
19
def validate_dict(value: Any, field_name: str = "value"):
    """
    Return ``value`` if it is a ``dict``; otherwise raise
    ``ValidationError`` naming ``field_name``.
    """
    if isinstance(value, dict):
        return value
    raise ValidationError(f"{field_name} must be a dictionary.")
24
+
25
+
26
def validate_key_in_dict(data: Dict, key: str):
    """
    Return ``data[key]``; raise ``ValidationError`` when ``key`` is absent.
    """
    if key in data:
        return data[key]
    raise ValidationError(f"Missing required key: '{key}'")
31
+
32
+
33
def validate_api_key(key: str):
    """
    Simple API key format validator.

    Returns ``key`` when it is exactly 32 hexadecimal characters.

    Raises:
        ValidationError: if ``key`` is empty, is not a string, or does not
            match the expected format.
    """
    validate_non_empty(key, "API key")

    # Example: basic 32-char hex key (you can change this pattern)
    pattern = r"^[A-Fa-f0-9]{32}$"

    # ``fullmatch`` is required here: with ``re.match`` the trailing ``$``
    # also matches just before a final newline, so "<32 hex chars>\n" would
    # be accepted. The type check turns a non-string key into a
    # ValidationError instead of a TypeError from the regex engine.
    if not isinstance(key, str) or re.fullmatch(pattern, key) is None:
        raise ValidationError("Invalid API key format.")

    return key
44
+
45
+
46
def validate_memory_id(mem_id: str):
    """
    Ensure memory ID follows a sane format.

    Returns ``mem_id`` when it consists only of ASCII letters, digits,
    underscores and hyphens.

    Raises:
        ValidationError: if ``mem_id`` is empty, is not a string, or
            contains any other character.
    """
    validate_non_empty(mem_id, "Memory ID")

    # ``fullmatch`` rejects a trailing newline, which ``re.match`` with a
    # ``$`` anchor would let through. The type check reports a non-string
    # id as a ValidationError rather than a TypeError.
    if not isinstance(mem_id, str) or re.fullmatch(r"^[A-Za-z0-9_\-]+$", mem_id) is None:
        raise ValidationError("Memory ID contains invalid characters.")

    return mem_id