orbmem 1.0.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- orbmem/__init__.py +3 -0
- orbmem/core/auth.py +26 -0
- orbmem/core/config.py +123 -0
- orbmem/core/ocdb.py +126 -0
- orbmem/db/mongo.py +25 -0
- orbmem/db/neo4j.py +31 -0
- orbmem/db/postgres.py +36 -0
- orbmem/db/redis.py +25 -0
- orbmem/engines/base_engine.py +47 -0
- orbmem/engines/graph/neo4j_backend.py +70 -0
- orbmem/engines/memory/postgres_backend.py +112 -0
- orbmem/engines/memory/redis_backend.py +85 -0
- orbmem/engines/safety/mongo_backend.py +97 -0
- orbmem/engines/safety/timeseries_backend.py +36 -0
- orbmem/engines/vector/qdrant_backend.py +75 -0
- orbmem/models/__init__.py +11 -0
- orbmem/models/fingerprints.py +13 -0
- orbmem/models/memory.py +15 -0
- orbmem/models/safety.py +19 -0
- orbmem/utils/embeddings.py +34 -0
- orbmem/utils/exceptions.py +30 -0
- orbmem/utils/helpers.py +54 -0
- orbmem/utils/logger.py +38 -0
- orbmem/utils/validators.py +53 -0
- orbmem-1.0.4.dist-info/METADATA +177 -0
- orbmem-1.0.4.dist-info/RECORD +29 -0
- orbmem-1.0.4.dist-info/WHEEL +5 -0
- orbmem-1.0.4.dist-info/licenses/LICENSE +23 -0
- orbmem-1.0.4.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,85 @@
|
|
|
1
|
+
# engines/memory/redis_backend.py
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
from typing import Any, Dict, Optional, List
|
|
5
|
+
from orbmem.db.redis import get_redis_client
|
|
6
|
+
from orbmem.utils.logger import get_logger
|
|
7
|
+
|
|
8
|
+
logger = get_logger(__name__)
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class RedisMemoryBackend:
    """
    High-speed TTL memory using Redis.
    Best for short-term memory, active sessions, caching.

    Values are serialized with ``json.dumps`` on write and parsed with
    ``json.loads`` on read. When ``get_redis_client()`` returns a falsy
    value the backend is disabled: writes become no-ops and reads return
    empty results.
    """

    def __init__(self):
        self.client = get_redis_client()
        if self.client:
            logger.info("RedisMemoryBackend initialized.")
        else:
            logger.warning("REDIS_URL not configured. RedisBackend disabled.")

    # ---------------------------------------------------------
    # Internal helpers
    # ---------------------------------------------------------
    @staticmethod
    def _as_text(key) -> str:
        """Return a Redis key as ``str`` whether the client yields bytes or str."""
        return key.decode() if isinstance(key, bytes) else key

    @staticmethod
    def _escape_glob(text: str) -> str:
        """Backslash-escape the characters Redis MATCH patterns treat specially."""
        return "".join(f"\\{ch}" if ch in "\\*?[]" else ch for ch in text)

    @staticmethod
    def _session_prefix(session_id: str) -> str:
        """Return the literal key prefix shared by every entry of a session."""
        return f"session:{session_id}:"

    def _session_keys(self, session_id: str):
        """Iterate the raw keys stored under ``session:<session_id>:``."""
        # SCAN-based iteration avoids the server-blocking KEYS command, and
        # escaping keeps a session id such as "*" from matching other sessions.
        pattern = self._escape_glob(self._session_prefix(session_id)) + "*"
        return self.client.scan_iter(match=pattern)

    # ---------------------------------------------------------
    # Core Operations
    # ---------------------------------------------------------
    def set(self, key: str, value: Any, ttl_seconds: Optional[int] = None):
        """
        Store ``value`` (JSON-encoded) under ``key``.

        A falsy ``ttl_seconds`` (``None`` or ``0``) stores the key without
        an expiry; otherwise the key expires after that many seconds.
        Does nothing when the backend is disabled.
        """
        if not self.client:
            return

        value_json = json.dumps(value)
        if ttl_seconds:
            self.client.setex(key, ttl_seconds, value_json)
        else:
            self.client.set(key, value_json)

        logger.info(f"[Redis] Set key '{key}'")

    def get(self, key: str) -> Optional[Any]:
        """Return the decoded value for ``key``, or ``None`` if absent/disabled."""
        if not self.client:
            return None

        val = self.client.get(key)
        if val is None:
            return None

        return json.loads(val)

    def delete(self, key: str):
        """Delete ``key``; a no-op when the backend is disabled."""
        if self.client:
            self.client.delete(key)
            logger.info(f"[Redis] Deleted key '{key}'")

    def keys(self) -> List[str]:
        """
        Return every key in the database (empty list when disabled).

        NOTE(review): elements are whatever the client returns for KEYS,
        which may be bytes depending on client configuration - confirm.
        """
        if not self.client:
            return []
        return list(self.client.keys("*"))

    # ---------------------------------------------------------
    # Session support via prefixing (session:<id>:key)
    # ---------------------------------------------------------
    def set_session(self, session_id: str, key: str, value: Any, ttl: Optional[int] = None):
        """Store ``value`` under ``session:<session_id>:<key>``."""
        full_key = f"session:{session_id}:{key}"
        self.set(full_key, value, ttl)

    def get_session(self, session_id: str) -> Dict[str, Any]:
        """
        Return ``{key: value}`` for every entry stored in the session.

        The session prefix is stripped by length rather than by splitting
        on ``:``, so session ids that themselves contain ``:`` still map
        back to the correct field names.
        """
        if not self.client:
            return {}

        prefix = self._session_prefix(session_id)
        results = {}
        for raw_key in self._session_keys(session_id):
            name = self._as_text(raw_key)
            if not name.startswith(prefix):
                # Defensive: ignore anything the pattern matched unexpectedly.
                continue
            results[name[len(prefix):]] = self.get(raw_key)
        return results

    def delete_session(self, session_id: str):
        """Delete every key stored under the session prefix."""
        if not self.client:
            return

        # Materialize first so deletion does not interleave with the scan.
        for raw_key in list(self._session_keys(session_id)):
            self.client.delete(raw_key)

        logger.info(f"[Redis] Deleted session '{session_id}'")
|
|
@@ -0,0 +1,97 @@
|
|
|
1
|
+
# engines/safety/mongo_backend.py
|
|
2
|
+
|
|
3
|
+
import re
|
|
4
|
+
import time
|
|
5
|
+
from typing import List, Dict, Any, Optional
|
|
6
|
+
|
|
7
|
+
from orbmem.db.mongo import get_mongo_client
|
|
8
|
+
from orbmem.utils.logger import get_logger
|
|
9
|
+
|
|
10
|
+
logger = get_logger(__name__)
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class SafetyEvent:
    """
    A single detected safety violation.

    Holds the offending text, the tag that matched, a severity score, an
    optional correction, free-form details, and the creation time taken
    from ``time.time()``.
    """

    def __init__(
        self,
        text: str,
        tag: str,
        severity: float,
        correction: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ):
        self.text, self.tag, self.severity = text, tag, severity
        # Falsy corrections (e.g. "") are normalized to None.
        self.correction = correction if correction else None
        # The ``metadata`` argument is exposed as ``details``; falsy -> {}.
        self.details = metadata if metadata else {}
        self.timestamp = time.time()

    def to_dict(self):
        """Return the event as a plain dict, keys in declaration order."""
        field_names = ("text", "tag", "severity", "correction", "details", "timestamp")
        return {name: getattr(self, name) for name in field_names}
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class MongoSafetyBackend:
    """
    Scans text for unsafe content and logs events to MongoDB.

    When ``get_mongo_client()`` returns a falsy value the backend still
    scans text but does not persist the resulting events.
    """

    DEFAULT_PATTERNS = {
        "self_harm": re.compile(r"(suicide|kill myself|hurt myself)", re.IGNORECASE),
        "violence": re.compile(r"(kill|shoot|stab|attack)", re.IGNORECASE),
        "hate": re.compile(r"(racial slur|hate\s+speech|bigot)", re.IGNORECASE),
        "privacy": re.compile(r"(password|otp|aadhaar|credit card)", re.IGNORECASE),
    }

    # Base score per tag before length scaling; unknown tags fall back to 0.5.
    _BASE_SEVERITY = {
        "self_harm": 0.9,
        "violence": 0.7,
        "hate": 0.8,
        "privacy": 0.6,
    }

    def __init__(self):
        self.client = get_mongo_client()
        if self.client:
            self.collection = self.client["ocdb"]["safety_events"]
            logger.info("MongoSafetyBackend initialized.")
        else:
            self.collection = None
            logger.warning("MongoSafetyBackend disabled (no MongoDB).")

    # ---------------------------------------------------------
    # Scoring model
    # ---------------------------------------------------------
    def _severity(self, tag: str, text: str) -> float:
        """
        Return the severity for ``tag`` scaled by the length of ``text``.

        The tag's base score is multiplied by ``len(text) / 200`` capped at
        1.0 and rounded to three decimals, so texts shorter than 200
        characters score proportionally lower.
        """
        base = self._BASE_SEVERITY.get(tag, 0.5)
        length_factor = min(len(text) / 200, 1.0)
        return round(base * length_factor, 3)

    # ---------------------------------------------------------
    # Persistence
    # ---------------------------------------------------------
    def _persist(self, evt) -> None:
        """Insert ``evt.to_dict()`` into the collection when one is configured."""
        # PyMongo collections do not support truth-value testing (bool()
        # raises NotImplementedError), so compare against None explicitly.
        if self.collection is not None:
            self.collection.insert_one(evt.to_dict())

    # ---------------------------------------------------------
    # Main scan function
    # ---------------------------------------------------------
    def scan(self, text: str, metadata: Optional[Dict[str, Any]] = None) -> "List[SafetyEvent]":
        """
        Check ``text`` against every pattern in ``DEFAULT_PATTERNS``.

        One ``SafetyEvent`` is created (and persisted, if a collection is
        available) for each tag whose pattern matches. Returns the list of
        created events; empty or falsy ``text`` yields an empty list.
        """
        if not text:
            return []

        events = []

        for tag, pattern in self.DEFAULT_PATTERNS.items():
            if not pattern.search(text):
                continue

            evt = SafetyEvent(text, tag, self._severity(tag, text), metadata=metadata)
            events.append(evt)
            self._persist(evt)

        return events
|
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
# engines/safety/timeseries_backend.py
|
|
2
|
+
|
|
3
|
+
import time
|
|
4
|
+
from typing import Dict, List
|
|
5
|
+
|
|
6
|
+
from orbmem.utils.logger import get_logger
|
|
7
|
+
|
|
8
|
+
logger = get_logger(__name__)
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class TimeSeriesSafetyBackend:
    """
    Keeps a per-tag list of timestamped safety scores in process memory.
    In production, this could be replaced with TimescaleDB, InfluxDB, etc.
    """

    def __init__(self):
        # tag -> list of {"timestamp": float, "score": float} points
        self.store: Dict[str, List[Dict]] = {}
        logger.info("TimeSeriesSafetyBackend initialized.")

    def add_point(self, tag: str, severity: float):
        """
        Append a point for ``tag`` stamped with the current ``time.time()``.
        """
        point = {"timestamp": time.time(), "score": severity}
        # setdefault creates the series on first use of a tag.
        self.store.setdefault(tag, []).append(point)

        logger.info(f"Added safety point: tag={tag}, severity={severity}")

    def get_series(self, tag: str):
        """Return the stored points for ``tag``, or an empty list if none."""
        if tag in self.store:
            return self.store[tag]
        return []
|
|
@@ -0,0 +1,75 @@
|
|
|
1
|
+
# engines/vector/qdrant_backend.py
|
|
2
|
+
# FAISS fallback vector engine (no Qdrant required)
|
|
3
|
+
|
|
4
|
+
import faiss
|
|
5
|
+
import numpy as np
|
|
6
|
+
from typing import List, Dict, Any
|
|
7
|
+
from ...utils.logger import get_logger
|
|
8
|
+
from ...utils.embeddings import embed_text
|
|
9
|
+
from orbmem.utils.exceptions import DatabaseError
|
|
10
|
+
|
|
11
|
+
logger = get_logger(__name__)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class QdrantVectorBackend:
    """
    Lightweight FAISS-based vector engine.
    Behaves like Qdrant but stores vectors in-memory.
    Fully compatible with OCDB vector API.

    Vectors live in a ``faiss.IndexFlatL2`` index; payloads are kept in a
    parallel list so that index position ``i`` maps to ``payloads[i]``.
    """

    def __init__(self, dim: int = 384):
        self.dim = dim

        # FAISS L2 index for similarity search
        self.index = faiss.IndexFlatL2(dim)

        # Store payloads manually, in insertion order
        self.payloads: List[Dict[str, Any]] = []

        logger.info("FAISS vector engine initialized (Qdrant replacement).")

    # ---------------------------------------------------------
    # Insert text into vector DB
    # ---------------------------------------------------------
    def add_text(self, text: str, payload: Dict[str, Any]):
        """
        Embed ``text`` and store it together with ``payload``.

        Raises:
            DatabaseError: if embedding or indexing fails; the original
                exception is attached as ``__cause__``.
        """
        try:
            vector = np.array([embed_text(text)], dtype="float32")

            self.index.add(vector)
            self.payloads.append(payload)

            logger.info(f"Vector added for ID={payload.get('id')}")

        except Exception as e:
            logger.error(f"FAISS add_text error: {e}")
            raise DatabaseError(str(e)) from e

    # ---------------------------------------------------------
    # Search for similar vectors
    # ---------------------------------------------------------
    def search(self, query: str, k: int = 5) -> List[Dict[str, Any]]:
        """
        Return up to ``k`` stored payloads nearest to ``query``.

        Each result is ``{"score": <distance as float>, "payload": <payload>}``
        in the order FAISS returns them. An empty index yields ``[]``.

        Raises:
            DatabaseError: if embedding or searching fails; the original
                exception is attached as ``__cause__``.
        """
        try:
            # Nothing stored yet: skip the embedding call entirely.
            if self.index.ntotal == 0:
                return []

            vector = np.array([embed_text(query)], dtype="float32")
            distances, indices = self.index.search(vector, k)

            # Entries with index -1 are skipped, as they have no payload.
            return [
                {"score": float(dist), "payload": self.payloads[idx]}
                for dist, idx in zip(distances[0], indices[0])
                if idx != -1
            ]

        except Exception as e:
            logger.error(f"FAISS search error: {e}")
            raise DatabaseError(str(e)) from e
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
# models/fingerprints.py
|
|
2
|
+
|
|
3
|
+
from sqlalchemy import Column, Integer, String, Float, TIMESTAMP
|
|
4
|
+
from sqlalchemy.sql import func
|
|
5
|
+
from orbmem.db.postgres import Base
|
|
6
|
+
|
|
7
|
+
class SafetyFingerprint(Base):
    """ORM model for the ``safety_fingerprints`` table: one scored tag per row."""

    __tablename__ = "safety_fingerprints"

    # Integer primary key.
    id = Column(Integer, primary_key=True)
    # Tag label; required and indexed.
    tag = Column(String, nullable=False, index=True)
    # Score recorded for the tag; required.
    score = Column(Float, nullable=False)
    # Timezone-aware timestamp defaulted by the database via now().
    timestamp = Column(TIMESTAMP(timezone=True), server_default=func.now())
|
orbmem/models/memory.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
# models/memory.py
|
|
2
|
+
|
|
3
|
+
from sqlalchemy import Column, Integer, String, JSON, TIMESTAMP
|
|
4
|
+
from sqlalchemy.sql import func
|
|
5
|
+
from orbmem.db.postgres import Base
|
|
6
|
+
|
|
7
|
+
class MemoryRecord(Base):
    """ORM model for the ``memory_records`` table: one JSON value per unique key."""

    __tablename__ = "memory_records"

    # Integer primary key (also explicitly indexed).
    id = Column(Integer, primary_key=True, index=True)
    # Record key; required, unique and indexed.
    key = Column(String, nullable=False, unique=True, index=True)
    # Optional session identifier; indexed.
    session_id = Column(String, nullable=True, index=True)
    # JSON payload; required.
    value = Column(JSON, nullable=False)
    # Timezone-aware creation time defaulted by the database via now().
    created_at = Column(TIMESTAMP(timezone=True), server_default=func.now())
    # Optional timezone-aware expiry time.
    expires_at = Column(TIMESTAMP(timezone=True), nullable=True)
|
orbmem/models/safety.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
# models/safety.py
|
|
2
|
+
|
|
3
|
+
from sqlalchemy import Column, Integer, String, Float, JSON, TIMESTAMP
|
|
4
|
+
from sqlalchemy.sql import func
|
|
5
|
+
from orbmem.db.postgres import Base
|
|
6
|
+
|
|
7
|
+
class SafetyEvent(Base):
    """ORM model for the ``safety_events`` table: one row per detected event."""

    __tablename__ = "safety_events"

    # Integer primary key.
    id = Column(Integer, primary_key=True)
    # Text that triggered the event; required.
    text = Column(String, nullable=False)
    # Tag label; required and indexed.
    tag = Column(String, nullable=False, index=True)
    # Severity score; required.
    severity = Column(Float, nullable=False)
    # Optional correction text.
    correction = Column(String, nullable=True)

    # CHANGED: metadata → details
    # NOTE(review): presumably renamed because ``metadata`` is a reserved
    # attribute on SQLAlchemy declarative classes - confirm.
    details = Column(JSON, nullable=True)

    # Timezone-aware timestamp defaulted by the database via now().
    timestamp = Column(TIMESTAMP(timezone=True), server_default=func.now())
|
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
# utils/embeddings.py
|
|
2
|
+
|
|
3
|
+
from functools import lru_cache
|
|
4
|
+
from sentence_transformers import SentenceTransformer
|
|
5
|
+
import numpy as np
|
|
6
|
+
from orbmem.utils.logger import get_logger
|
|
7
|
+
|
|
8
|
+
logger = get_logger(__name__)
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@lru_cache(maxsize=1)
def get_embedding_model():
    """
    Build the SentenceTransformer model and cache it.

    ``lru_cache(maxsize=1)`` means the model is constructed on the first
    call and the same instance is returned afterwards.
    The checkpoint is all-MiniLM-L6-v2 (384-dimensional vectors, per the
    original note on this function).
    """
    checkpoint = "sentence-transformers/all-MiniLM-L6-v2"
    logger.info(f"Loading embedding model: {checkpoint}")
    model = SentenceTransformer(checkpoint)
    return model
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def embed_text(text: str):
    """
    Embed ``text`` and return the vector as a plain list of floats.

    Empty, falsy, or non-string input short-circuits to a list of 384
    zeros without loading the model.
    """
    if text and isinstance(text, str):
        encoded = get_embedding_model().encode(text)
        # NumPy array -> list of Python floats (JSON serializable).
        return encoded.astype(float).tolist()

    return [0.0] * 384
|
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
# utils/exceptions.py
|
|
2
|
+
|
|
3
|
+
class OCDBError(Exception):
    """Root of the OCDB exception hierarchy; the other OCDB errors derive from it."""
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class ConfigError(OCDBError):
    """Signals missing or invalid configuration."""
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class DatabaseError(OCDBError):
    """Signals a database connection or query failure."""
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class AuthError(OCDBError):
    """Signals a failed authentication or API key validation."""
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class NotFoundError(OCDBError):
    """Signals that a requested resource does not exist."""
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class ValidationError(OCDBError):
    """Signals invalid user input or payload."""
|
orbmem/utils/helpers.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
1
|
+
# utils/helpers.py
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
import time
|
|
5
|
+
from typing import Any, Dict
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def now_ts() -> float:
    """
    Current Unix timestamp, as reported by ``time.time()``.
    """
    current = time.time()
    return current
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def safe_json(data: Any) -> str:
    """
    Serialize ``data`` to a JSON string without raising.

    Non-ASCII characters are kept as-is and objects JSON cannot encode are
    passed through ``str``. If serialization still fails, the literal
    string ``"{}"`` is returned instead.
    """
    try:
        encoded = json.dumps(data, default=str, ensure_ascii=False)
    except Exception:
        # Best-effort helper: fall back to an empty JSON object.
        encoded = "{}"
    return encoded
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def deep_clean_dict(data: Dict[str, Any]) -> Dict[str, Any]:
    """
    Return a copy of ``data`` without entries whose value is ``None``.

    Nested dict values are cleaned the same way. Other containers (lists,
    tuples, ...) are copied over untouched, and a non-dict argument is
    returned unchanged.
    """
    if not isinstance(data, dict):
        return data

    # The recursive call is the identity for non-dict values, so it can be
    # applied to every surviving entry.
    return {
        name: deep_clean_dict(item)
        for name, item in data.items()
        if item is not None
    }
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def ensure_str(value: Any) -> str:
    """
    Return ``str(value)``, or ``"<unprintable>"`` if that conversion raises.
    """
    try:
        rendered = str(value)
    except Exception:
        rendered = "<unprintable>"
    return rendered
|
orbmem/utils/logger.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
# utils/logger.py
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
import sys
|
|
5
|
+
from typing import Optional
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def _create_handler() -> logging.Handler:
    """Build a stdout stream handler with the timestamp/level/name format."""
    console = logging.StreamHandler(sys.stdout)
    console.setFormatter(
        logging.Formatter(
            fmt="[%(asctime)s] [%(levelname)s] [%(name)s] %(message)s",
            datefmt="%Y-%m-%d %H:%M:%S",
        )
    )
    return console
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def get_logger(name: Optional[str] = None) -> logging.Logger:
    """
    Return the logger for ``name`` (``"ocdb"`` when ``name`` is falsy).

    The first time a given logger is requested it receives the console
    handler, is set to INFO, and stops propagating to ancestor loggers.
    A logger that already has handlers is returned untouched.

    Usage:
        from orbmem.utils.logger import get_logger
        logger = get_logger(__name__)
        logger.info("Hello from OCDB")
    """
    log = logging.getLogger(name or "ocdb")

    # Already configured: do not stack a second handler.
    if log.handlers:
        return log

    log.addHandler(_create_handler())
    log.setLevel(logging.INFO)
    log.propagate = False
    return log
|
|
@@ -0,0 +1,53 @@
|
|
|
1
|
+
# utils/validators.py
|
|
2
|
+
|
|
3
|
+
import re
|
|
4
|
+
from typing import Any, Dict
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class ValidationError(Exception):
    """
    Error raised by the validators in this module.

    NOTE(review): orbmem/utils/exceptions.py defines a separate
    ``ValidationError`` derived from ``OCDBError``; this class derives
    directly from ``Exception`` and is not part of that hierarchy.
    """
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def validate_non_empty(value: Any, field_name: str = "value"):
    """
    Return ``value`` unless it is ``None`` or a blank string.

    Raises ``ValidationError`` naming ``field_name`` for ``None`` and for
    strings that are empty after stripping whitespace. Other falsy values
    (``0``, ``[]``, ``{}``) are accepted.
    """
    is_blank_text = isinstance(value, str) and value.strip() == ""
    if value is None or is_blank_text:
        raise ValidationError(f"{field_name} cannot be empty.")
    return value
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def validate_dict(value: Any, field_name: str = "value"):
    """
    Return ``value`` if it is a ``dict`` (or subclass); otherwise raise
    ``ValidationError`` naming ``field_name``.
    """
    if isinstance(value, dict):
        return value
    raise ValidationError(f"{field_name} must be a dictionary.")
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def validate_key_in_dict(data: Dict, key: str):
    """
    Return ``data[key]``, raising ``ValidationError`` if ``key`` is absent.
    """
    if key in data:
        return data[key]
    raise ValidationError(f"Missing required key: '{key}'")
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def validate_api_key(key: str):
    """
    Simple API key format validator.

    Accepts exactly 32 hexadecimal characters and returns the key
    unchanged.

    Raises:
        ValidationError: if the key is ``None``/blank, is not a string, or
            does not consist of exactly 32 hex characters.
    """
    validate_non_empty(key, "API key")

    # Example: basic 32-char hex key (you can change this pattern).
    # fullmatch is used because ``$`` with re.match also matches just
    # before a trailing newline, which would let "<32 hex>\n" through.
    pattern = r"[A-Fa-f0-9]{32}"

    # Non-string input is reported as a validation failure rather than
    # surfacing as a TypeError from the regex engine.
    if not isinstance(key, str) or not re.fullmatch(pattern, key):
        raise ValidationError("Invalid API key format.")

    return key
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def validate_memory_id(mem_id: str):
    """
    Ensure memory ID follows a sane format.

    Accepts one or more ASCII letters, digits, underscores or hyphens and
    returns the ID unchanged.

    Raises:
        ValidationError: if the ID is ``None``/blank, is not a string, or
            contains any other character.
    """
    validate_non_empty(mem_id, "Memory ID")

    # fullmatch rather than match + ``$``: ``$`` also matches just before
    # a trailing newline, so "abc\n" used to be accepted. Non-string input
    # is rejected here instead of raising TypeError inside ``re``.
    if not isinstance(mem_id, str) or not re.fullmatch(r"[A-Za-z0-9_\-]+", mem_id):
        raise ValidationError("Memory ID contains invalid characters.")

    return mem_id
|