hammad-python 0.0.13__py3-none-any.whl → 0.0.15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hammad_python-0.0.15.dist-info/METADATA +184 -0
- hammad_python-0.0.15.dist-info/RECORD +4 -0
- hammad/__init__.py +0 -180
- hammad/_core/__init__.py +0 -1
- hammad/_core/_utils/__init__.py +0 -4
- hammad/_core/_utils/_import_utils.py +0 -182
- hammad/ai/__init__.py +0 -59
- hammad/ai/_utils.py +0 -142
- hammad/ai/completions/__init__.py +0 -44
- hammad/ai/completions/client.py +0 -729
- hammad/ai/completions/create.py +0 -686
- hammad/ai/completions/types.py +0 -711
- hammad/ai/completions/utils.py +0 -374
- hammad/ai/embeddings/__init__.py +0 -35
- hammad/ai/embeddings/client/__init__.py +0 -1
- hammad/ai/embeddings/client/base_embeddings_client.py +0 -26
- hammad/ai/embeddings/client/fastembed_text_embeddings_client.py +0 -200
- hammad/ai/embeddings/client/litellm_embeddings_client.py +0 -288
- hammad/ai/embeddings/create.py +0 -159
- hammad/ai/embeddings/types.py +0 -69
- hammad/base/__init__.py +0 -35
- hammad/base/fields.py +0 -546
- hammad/base/model.py +0 -1078
- hammad/base/utils.py +0 -280
- hammad/cache/__init__.py +0 -48
- hammad/cache/base_cache.py +0 -181
- hammad/cache/cache.py +0 -169
- hammad/cache/decorators.py +0 -261
- hammad/cache/file_cache.py +0 -80
- hammad/cache/ttl_cache.py +0 -74
- hammad/cli/__init__.py +0 -33
- hammad/cli/animations.py +0 -604
- hammad/cli/plugins.py +0 -781
- hammad/cli/styles/__init__.py +0 -55
- hammad/cli/styles/settings.py +0 -139
- hammad/cli/styles/types.py +0 -358
- hammad/cli/styles/utils.py +0 -480
- hammad/configuration/__init__.py +0 -35
- hammad/configuration/configuration.py +0 -564
- hammad/data/__init__.py +0 -39
- hammad/data/collections/__init__.py +0 -34
- hammad/data/collections/base_collection.py +0 -58
- hammad/data/collections/collection.py +0 -452
- hammad/data/collections/searchable_collection.py +0 -556
- hammad/data/collections/vector_collection.py +0 -603
- hammad/data/databases/__init__.py +0 -21
- hammad/data/databases/database.py +0 -902
- hammad/json/__init__.py +0 -21
- hammad/json/converters.py +0 -152
- hammad/logging/__init__.py +0 -35
- hammad/logging/decorators.py +0 -834
- hammad/logging/logger.py +0 -954
- hammad/multimodal/__init__.py +0 -24
- hammad/multimodal/audio.py +0 -96
- hammad/multimodal/image.py +0 -80
- hammad/multithreading/__init__.py +0 -304
- hammad/py.typed +0 -0
- hammad/pydantic/__init__.py +0 -43
- hammad/pydantic/converters.py +0 -623
- hammad/pydantic/models/__init__.py +0 -28
- hammad/pydantic/models/arbitrary_model.py +0 -46
- hammad/pydantic/models/cacheable_model.py +0 -79
- hammad/pydantic/models/fast_model.py +0 -318
- hammad/pydantic/models/function_model.py +0 -176
- hammad/pydantic/models/subscriptable_model.py +0 -63
- hammad/text/__init__.py +0 -82
- hammad/text/converters.py +0 -723
- hammad/text/markdown.py +0 -131
- hammad/text/text.py +0 -1066
- hammad/types/__init__.py +0 -11
- hammad/types/file.py +0 -358
- hammad/typing/__init__.py +0 -407
- hammad/web/__init__.py +0 -43
- hammad/web/http/__init__.py +0 -1
- hammad/web/http/client.py +0 -944
- hammad/web/models.py +0 -245
- hammad/web/openapi/__init__.py +0 -0
- hammad/web/openapi/client.py +0 -740
- hammad/web/search/__init__.py +0 -1
- hammad/web/search/client.py +0 -988
- hammad/web/utils.py +0 -472
- hammad/yaml/__init__.py +0 -30
- hammad/yaml/converters.py +0 -19
- hammad_python-0.0.13.dist-info/METADATA +0 -38
- hammad_python-0.0.13.dist-info/RECORD +0 -85
- {hammad_python-0.0.13.dist-info → hammad_python-0.0.15.dist-info}/WHEEL +0 -0
- {hammad_python-0.0.13.dist-info → hammad_python-0.0.15.dist-info}/licenses/LICENSE +0 -0
@@ -1,603 +0,0 @@
|
|
1
|
-
"""hammad.data.collections.vector_collection"""
|
2
|
-
|
3
|
-
import uuid
|
4
|
-
from typing import Any, Dict, Optional, List, Generic, Union, Callable
|
5
|
-
from datetime import datetime, timezone, timedelta
|
6
|
-
|
7
|
-
try:
    from qdrant_client import QdrantClient
    from qdrant_client.models import (
        Distance,
        VectorParams,
        PointStruct,
        Filter,
        FieldCondition,
        MatchValue,
        SearchRequest,
        QueryResponse,
    )
    import numpy as np
except ImportError as e:
    # Any failure among the guarded imports above (numpy included) is
    # re-raised with installation instructions; the original error stays
    # chained via ``from e``.
    # Adjacent string literals are concatenated verbatim, so every fragment
    # has to end with its own separator.
    raise ImportError(
        "qdrant-client is required for VectorCollection. "
        "Install with: `pip install qdrant-client`. "
        "Or install the `ai` extra: `pip install hammad-python[ai]`"
    ) from e
|
26
|
-
|
27
|
-
from .base_collection import BaseCollection, Object, Filters, Schema
|
28
|
-
from ...ai.embeddings.create import (
|
29
|
-
create_embeddings,
|
30
|
-
async_create_embeddings,
|
31
|
-
)
|
32
|
-
from ...ai.embeddings.client.fastembed_text_embeddings_client import (
|
33
|
-
FastEmbedTextEmbeddingModel,
|
34
|
-
)
|
35
|
-
from ...ai.embeddings.client.litellm_embeddings_client import (
|
36
|
-
LiteLlmEmbeddingModel,
|
37
|
-
)
|
38
|
-
|
39
|
-
# Public API of this module: the only name exported by a star-import.
__all__ = ("VectorCollection",)
|
40
|
-
|
41
|
-
|
42
|
-
class VectorCollection(BaseCollection, Generic[Object]):
|
43
|
-
"""
|
44
|
-
Vector collection class that uses Qdrant for vector storage and similarity search.
|
45
|
-
|
46
|
-
This provides vector-based functionality for storing embeddings and performing
|
47
|
-
semantic similarity searches.
|
48
|
-
"""
|
49
|
-
|
50
|
-
# Namespace UUID for generating deterministic UUIDs from string IDs
|
51
|
-
# Same value as the standard library's ``uuid.NAMESPACE_DNS``. It is the
# uuid5 namespace used by ``_ensure_uuid`` to derive a stable point ID from
# an arbitrary (non-UUID) string ID.
_NAMESPACE_UUID = uuid.UUID("6ba7b810-9dad-11d1-80b4-00c04fd430c8")
|
52
|
-
|
53
|
-
def __init__(
    self,
    name: str,
    vector_size: int,
    schema: Optional[Schema] = None,
    default_ttl: Optional[int] = None,
    storage_backend: Optional[Any] = None,
    distance_metric: Distance = Distance.DOT,
    qdrant_config: Optional[Dict[str, Any]] = None,
    embedding_function: Optional[Callable[[Any], List[float]]] = None,
    model: Optional[str] = None,
    # Common embedding parameters
    format: bool = False,
    # LiteLLM parameters
    dimensions: Optional[int] = None,
    encoding_format: Optional[str] = None,
    timeout: Optional[int] = None,
    api_base: Optional[str] = None,
    api_version: Optional[str] = None,
    api_key: Optional[str] = None,
    api_type: Optional[str] = None,
    caching: bool = False,
    user: Optional[str] = None,
    # FastEmbed parameters
    parallel: Optional[int] = None,
    batch_size: Optional[int] = None,
):
    """
    Set up the collection state and connect to Qdrant.

    Args:
        name: Name of the collection (also used as the Qdrant collection name).
        vector_size: Dimension every stored and queried vector must have.
        schema: Optional schema; stored on the instance and passed on by
            `attach_to_database`.
        default_ttl: TTL in seconds applied when `add` is called without one.
        storage_backend: Optional backend that `get`, `add` and `query`
            delegate to when it is not None.
        distance_metric: Qdrant distance used when the collection is created
            (defaults to `Distance.DOT`).
        qdrant_config: Optional client settings. A "path" key selects local
            persistent storage, a "host" key selects a remote server
            (with optional "port", "grpc_port", "prefer_grpc", "api_key",
            "timeout"); without either an in-memory client is used.
        embedding_function: Optional callable turning an entry into a vector.
        model: Optional embedding model name. When truthy, an embedding
            function built from it replaces `embedding_function`.
        format: Common embedding parameter forwarded to `create_embeddings`.

        # LiteLLM-specific parameters (forwarded only when not None):
        dimensions, encoding_format, timeout, api_base, api_version,
        api_key, api_type, caching, user

        # FastEmbed-specific parameters (forwarded only when not None):
        parallel, batch_size
    """
    # Plain configuration captured from the arguments
    self.name, self.vector_size = name, vector_size
    self.schema, self.default_ttl = schema, default_ttl
    self.distance_metric = distance_metric
    self._storage_backend = storage_backend
    self._model = model
    self._qdrant_config = qdrant_config if qdrant_config else {}

    # Local mirror of stored payloads, keyed by point UUID
    self._items: Dict[str, Dict[str, Any]] = {}

    # Caller-supplied ID -> derived UUID
    self._id_mapping: Dict[str, str] = {}

    # Keyword arguments later handed to the embedding call; insertion order
    # matches the parameter order of the signature.
    self._embedding_params = dict(
        format=format,
        # LiteLLM parameters
        dimensions=dimensions,
        encoding_format=encoding_format,
        timeout=timeout,
        api_base=api_base,
        api_version=api_version,
        api_key=api_key,
        api_type=api_type,
        caching=caching,
        user=user,
        # FastEmbed parameters
        parallel=parallel,
        batch_size=batch_size,
    )

    # An explicit model wins over a caller-supplied embedding function
    self._embedding_function = embedding_function
    if model:
        self._embedding_function = self._create_embedding_function(model)

    # Connect last, once every attribute the client setup reads is in place
    self._init_qdrant_client()
|
160
|
-
|
161
|
-
def _create_embedding_function(
|
162
|
-
self,
|
163
|
-
model_name: str,
|
164
|
-
) -> Callable[[Any], List[float]]:
|
165
|
-
"""Create an embedding function from a model name."""
|
166
|
-
|
167
|
-
def embedding_function(text: Any) -> List[float]:
|
168
|
-
if not isinstance(text, str):
|
169
|
-
text = str(text)
|
170
|
-
|
171
|
-
# Filter out None values from embedding parameters
|
172
|
-
embedding_kwargs = {
|
173
|
-
k: v for k, v in self._embedding_params.items() if v is not None
|
174
|
-
}
|
175
|
-
embedding_kwargs["model"] = model_name
|
176
|
-
embedding_kwargs["input"] = text
|
177
|
-
|
178
|
-
response = create_embeddings(**embedding_kwargs)
|
179
|
-
return response.data[0].embedding
|
180
|
-
|
181
|
-
return embedding_function
|
182
|
-
|
183
|
-
def _init_qdrant_client(self):
    """Initialize the Qdrant client and make sure the collection exists.

    The client is chosen from ``self._qdrant_config``:
      * a ``"path"`` key  -> local persistent storage at that path
      * a ``"host"`` key  -> remote server (``port``, ``grpc_port``,
        ``prefer_grpc``, ``api_key`` and ``timeout`` are optional)
      * neither           -> in-memory instance

    Creating the collection stays best-effort: a failure does not raise,
    but it is logged as a warning instead of being discarded silently.
    """
    import logging

    config = self._qdrant_config

    if "path" in config:
        # Persistent local storage
        self._client = QdrantClient(path=config["path"])
    elif "host" in config:
        # Remote Qdrant server
        self._client = QdrantClient(
            host=config.get("host", "localhost"),
            port=config.get("port", 6333),
            grpc_port=config.get("grpc_port", 6334),
            prefer_grpc=config.get("prefer_grpc", False),
            api_key=config.get("api_key"),
            timeout=config.get("timeout"),
        )
    else:
        # In-memory database (default)
        self._client = QdrantClient(":memory:")

    # Create the collection only if the server does not list it yet
    try:
        collections = self._client.get_collections()
        collection_names = [col.name for col in collections.collections]

        if self.name not in collection_names:
            self._client.create_collection(
                collection_name=self.name,
                vectors_config=VectorParams(
                    size=self.vector_size, distance=self.distance_metric
                ),
            )
    except Exception:
        # Construction must not fail here (e.g. the collection was created
        # concurrently), but the cause is recorded so later failures of
        # add/query can be traced back to it.
        logging.getLogger(__name__).warning(
            "Could not verify or create Qdrant collection %r",
            self.name,
            exc_info=True,
        )
|
219
|
-
|
220
|
-
def _ensure_uuid(self, id_str: str) -> str:
|
221
|
-
"""Convert a string ID to a UUID string, or validate if already a UUID."""
|
222
|
-
# Check if it's already a valid UUID
|
223
|
-
try:
|
224
|
-
uuid.UUID(id_str)
|
225
|
-
return id_str
|
226
|
-
except ValueError:
|
227
|
-
# Not a valid UUID, create a deterministic one
|
228
|
-
new_uuid = str(uuid.uuid5(self._NAMESPACE_UUID, id_str))
|
229
|
-
self._id_mapping[id_str] = new_uuid
|
230
|
-
return new_uuid
|
231
|
-
|
232
|
-
def __repr__(self) -> str:
|
233
|
-
item_count = len(self._items) if self._storage_backend is None else "managed"
|
234
|
-
return f"<{self.__class__.__name__} name='{self.name}' vector_size={self.vector_size} items={item_count}>"
|
235
|
-
|
236
|
-
def _calculate_expires_at(self, ttl: Optional[int]) -> Optional[datetime]:
|
237
|
-
"""Calculate expiry time based on TTL."""
|
238
|
-
if ttl is None:
|
239
|
-
ttl = self.default_ttl
|
240
|
-
if ttl and ttl > 0:
|
241
|
-
return datetime.now(timezone.utc) + timedelta(seconds=ttl)
|
242
|
-
return None
|
243
|
-
|
244
|
-
def _is_expired(self, expires_at: Optional[datetime]) -> bool:
|
245
|
-
"""Check if an item has expired."""
|
246
|
-
if expires_at is None:
|
247
|
-
return False
|
248
|
-
now = datetime.now(timezone.utc)
|
249
|
-
if expires_at.tzinfo is None:
|
250
|
-
expires_at = expires_at.replace(tzinfo=timezone.utc)
|
251
|
-
return now >= expires_at
|
252
|
-
|
253
|
-
def _match_filters(
    self, stored: Optional[Filters], query: Optional[Filters]
) -> bool:
    """Tell whether ``stored`` satisfies every key/value pair of ``query``.

    No query matches everything; a query against missing stored filters
    matches nothing.
    """
    if query is None:
        return True
    if stored is None:
        return False
    for key, wanted in query.items():
        if not (stored.get(key) == wanted):
            return False
    return True
|
262
|
-
|
263
|
-
def _prepare_vector(self, entry: Any) -> List[float]:
|
264
|
-
"""Prepare vector from entry using embedding function or direct vector."""
|
265
|
-
if self._embedding_function:
|
266
|
-
return self._embedding_function(entry)
|
267
|
-
elif isinstance(entry, dict) and "vector" in entry:
|
268
|
-
vector = entry["vector"]
|
269
|
-
if isinstance(vector, np.ndarray):
|
270
|
-
return vector.tolist()
|
271
|
-
elif isinstance(vector, list):
|
272
|
-
return vector
|
273
|
-
else:
|
274
|
-
raise ValueError("Vector must be a list or numpy array")
|
275
|
-
elif isinstance(entry, (list, np.ndarray)):
|
276
|
-
if isinstance(entry, np.ndarray):
|
277
|
-
return entry.tolist()
|
278
|
-
return entry
|
279
|
-
else:
|
280
|
-
raise ValueError(
|
281
|
-
"Entry must contain 'vector' key, be a vector itself, "
|
282
|
-
"or embedding_function must be provided"
|
283
|
-
)
|
284
|
-
|
285
|
-
def _build_qdrant_filter(self, filters: Optional[Filters]) -> Optional[Filter]:
    """Translate a filters mapping into a Qdrant ``Filter``.

    Every key/value pair becomes an exact-match ``FieldCondition``; all
    conditions are combined under ``must``. Empty or missing filters
    yield ``None``.
    """
    if not filters:
        return None

    must_conditions = [
        FieldCondition(key=field, match=MatchValue(value=expected))
        for field, expected in filters.items()
    ]
    return Filter(must=must_conditions)
|
298
|
-
|
299
|
-
def get(self, id: str, *, filters: Optional[Filters] = None) -> Optional[Object]:
    """Get an item by ID.

    When a storage backend is attached the call is delegated to it.
    Otherwise the point is read from Qdrant under the UUID derived from
    ``id``.

    Args:
        id: Caller-facing ID of the item.
        filters: Optional key/value pairs that must all equal the
            corresponding top-level payload fields.

    Returns:
        The stored value, or ``None`` when the item is missing, expired,
        rejected by ``filters``, or the lookup raised.
    """
    if self._storage_backend is not None:
        # Delegate to storage backend (Database instance)
        return self._storage_backend.get(id, collection=self.name, filters=filters)

    # Convert ID to UUID if needed
    uuid_id = self._ensure_uuid(id)

    # Independent operation
    try:
        points = self._client.retrieve(
            collection_name=self.name,
            ids=[uuid_id],
            with_payload=True,
            with_vectors=False,
        )

        if not points:
            return None

        payload = points[0].payload or {}

        # Expired items are removed lazily, on first access after expiry
        expires_at_str = payload.get("expires_at")
        if expires_at_str and self._is_expired(
            datetime.fromisoformat(expires_at_str)
        ):
            self._client.delete(
                collection_name=self.name, points_selector=[uuid_id]
            )
            # Keep the in-memory mirror in step with Qdrant, the same way
            # delete() does; otherwise the expired payload lingers and
            # inflates the item count shown by __repr__.
            self._items.pop(uuid_id, None)
            return None

        # Filters are stored as top-level fields in the payload
        if filters:
            for key, value in filters.items():
                if payload.get(key) != value:
                    return None

        return payload.get("value")

    except Exception:
        # Best-effort lookup: any client or parsing error reads as "not found"
        return None
|
344
|
-
|
345
|
-
def add(
    self,
    entry: Object,
    id: Optional[str] = None,
    *,
    filters: Optional[Filters] = None,
    ttl: Optional[int] = None,
) -> str:
    """Add an item to the collection.

    Args:
        entry: The object/data to store
        id: Optional ID for the item (a UUID4 is generated if not provided)
        filters: Optional metadata, stored as top-level payload fields
        ttl: Time-to-live in seconds (falls back to the collection default)

    Returns:
        The ID of the added item: the caller's ``id`` or the generated one.

    Raises:
        ValueError: if the prepared vector's length differs from
            ``vector_size`` (standalone mode only).
    """
    # Resolve the ID once, up front, so the value handed to the backend and
    # the value returned to the caller are always the same. Previously a
    # missing ID was passed to the backend as None while an unrelated,
    # freshly generated UUID was returned.
    item_id = id or str(uuid.uuid4())

    if self._storage_backend is not None:
        # Delegate to storage backend
        # NOTE(review): assumes the backend stores the item under the `id`
        # keyword it is given - confirm against the Database implementation.
        self._storage_backend.add(
            entry, id=item_id, collection=self.name, filters=filters, ttl=ttl
        )
        return item_id

    # Independent operation: Qdrant needs a UUID (or integer) point ID
    uuid_id = self._ensure_uuid(item_id)

    expires_at = self._calculate_expires_at(ttl)
    created_at = datetime.now(timezone.utc).isoformat()

    # Prepare vector
    vector = self._prepare_vector(entry)

    if len(vector) != self.vector_size:
        raise ValueError(
            f"Vector size {len(vector)} doesn't match collection size {self.vector_size}"
        )

    payload = {
        "value": entry,
        "created_at": created_at,
        "updated_at": created_at,
    }

    # Filter fields become top-level payload fields so Qdrant can match on
    # them. NOTE: a filter key equal to one of the keys above replaces it.
    if filters:
        payload.update(filters)

    # Remember the caller-facing ID when it had to be converted
    if item_id != uuid_id:
        payload["original_id"] = item_id

    if expires_at:
        payload["expires_at"] = expires_at.isoformat()

    # Write to Qdrant first; the in-memory mirror is only updated once the
    # upsert succeeded, so a failed write leaves no phantom entry behind.
    point = PointStruct(id=uuid_id, vector=vector, payload=payload)
    self._client.upsert(collection_name=self.name, points=[point])
    self._items[uuid_id] = payload

    return item_id
|
415
|
-
|
416
|
-
def query(
    self,
    query: Optional[str] = None,
    *,
    filters: Optional[Filters] = None,
    limit: Optional[int] = None,
) -> List[Object]:
    """Query items from the collection.

    Args:
        query: Search query string. If provided, it is embedded and a
            similarity search is performed; if None, items are listed
            without vector search.
        filters: Optional filters to apply
        limit: Maximum number of results to return. When None, the callee's
            own default applies.

    Returns:
        The matching stored values.

    Raises:
        ValueError: if ``query`` is given but no embedding function is
            configured (standalone mode only).
    """
    if self._storage_backend is not None:
        return self._storage_backend.query(
            collection=self.name,
            filters=filters,
            search=query,
            limit=limit,
        )

    # For basic query without vector search, just return items with filters
    if query is None:
        return self._query_all(filters=filters, limit=limit)

    # A text query cannot be answered without a way to embed it
    if self._embedding_function is None:
        raise ValueError(
            "Search query provided but no embedding_function configured. "
            "Use vector_search() for direct vector similarity search."
        )

    # Convert search to vector and perform similarity search
    query_vector = self._embedding_function(query)

    # vector_search() declares `limit: int = 10`; forwarding None would
    # override that default with a value outside its contract, so the
    # argument is only passed when the caller actually supplied one.
    search_kwargs = {"query_vector": query_vector, "filters": filters}
    if limit is not None:
        search_kwargs["limit"] = limit
    return self.vector_search(**search_kwargs)
|
454
|
-
|
455
|
-
def _query_all(
    self,
    *,
    filters: Optional[Filters] = None,
    limit: Optional[int] = None,
) -> List[Object]:
    """Query all items with optional filters (no vector search).

    The collection is read page by page through Qdrant's scroll API, so the
    result is no longer cut off after the first page of 100 points, and
    expired points skipped along the way do not use up the requested limit.

    Args:
        filters: Optional exact-match filters on top-level payload fields.
        limit: Maximum number of values to return. ``None`` (or 0, which was
            already treated as "unset") returns every live item.

    Returns:
        The stored values of all non-expired matching points, or an empty
        list if the client raised.
    """
    page_size = 100

    try:
        scroll_filter = self._build_qdrant_filter(filters)
        results = []
        offset = None

        while True:
            # Never ask for more points than are still needed
            batch = min(page_size, limit - len(results)) if limit else page_size

            # scroll() returns the page plus the offset of the next page,
            # which is None once the collection is exhausted.
            points, offset = self._client.scroll(
                collection_name=self.name,
                scroll_filter=scroll_filter,
                limit=batch,
                offset=offset,
                with_payload=True,
                with_vectors=False,
            )

            for point in points:
                payload = point.payload or {}

                # Skip expired points (they are not deleted here)
                expires_at_str = payload.get("expires_at")
                if expires_at_str and self._is_expired(
                    datetime.fromisoformat(expires_at_str)
                ):
                    continue

                results.append(payload.get("value"))

            if offset is None or (limit and len(results) >= limit):
                break

        return results

    except Exception:
        # Best-effort listing: any client or parsing error yields no results
        return []
|
489
|
-
|
490
|
-
def vector_search(
    self,
    query_vector: Union[List[float], np.ndarray],
    *,
    filters: Optional[Filters] = None,
    limit: int = 10,
    score_threshold: Optional[float] = None,
) -> List[Object]:
    """
    Run a similarity search for ``query_vector`` against the collection.

    Args:
        query_vector: Vector to search with (list or numpy array)
        filters: Optional exact-match filters on payload fields
        limit: Maximum number of points requested from Qdrant (default: 10)
        score_threshold: Minimum score passed through to Qdrant

    Returns:
        Stored values of the non-expired hits, in the order Qdrant returned
        them; an empty list if the client raised.

    Raises:
        ValueError: if the vector length differs from ``vector_size``.
    """
    if isinstance(query_vector, np.ndarray):
        query_vector = query_vector.tolist()

    if len(query_vector) != self.vector_size:
        raise ValueError(
            f"Query vector size {len(query_vector)} doesn't match collection size {self.vector_size}"
        )

    try:
        response = self._client.query_points(
            collection_name=self.name,
            query=query_vector,
            query_filter=self._build_qdrant_filter(filters),
            limit=limit,
            score_threshold=score_threshold,
            with_payload=True,
            with_vectors=False,
        )

        live_values = []
        for hit in response.points:
            data = hit.payload or {}

            # Expired hits are dropped from the result (not deleted here)
            stamp = data.get("expires_at")
            if stamp and self._is_expired(datetime.fromisoformat(stamp)):
                continue

            live_values.append(data.get("value"))

        return live_values

    except Exception:
        return []
|
546
|
-
|
547
|
-
def get_vector(self, id: str) -> Optional[List[float]]:
    """Return the stored vector of the item with the given ID.

    For a point with named vectors (a dict) the first vector is returned.
    ``None`` is returned when the point is missing, has no vector, or the
    lookup raised.
    """
    # Qdrant is addressed by the UUID derived from the caller's ID
    point_id = self._ensure_uuid(id)

    try:
        records = self._client.retrieve(
            collection_name=self.name,
            ids=[point_id],
            with_payload=False,
            with_vectors=True,
        )
        if not records:
            return None

        stored = records[0].vector
        if not isinstance(stored, dict):
            return stored

        # Named vectors: hand back the first one, if any
        if not stored:
            return None
        return next(iter(stored.values()))

    except Exception:
        return None
|
571
|
-
|
572
|
-
def delete(self, id: str) -> bool:
    """Delete the item with the given ID from Qdrant and the local mirror.

    Returns:
        True when the delete request completed, False when it raised.
    """
    # Qdrant is addressed by the UUID derived from the caller's ID
    point_id = self._ensure_uuid(id)

    try:
        self._client.delete(
            collection_name=self.name,
            points_selector=[point_id],
        )
        # Drop the mirrored payload too; a missing key is not an error
        self._items.pop(point_id, None)
    except Exception:
        return False
    else:
        return True
|
584
|
-
|
585
|
-
def count(self, *, filters: Optional[Filters] = None) -> int:
    """Return the exact number of points in the collection.

    ``filters`` restricts the count to points matching them. No expiry
    check is made, so expired points that have not been deleted yet are
    included. Returns 0 if the client raised.
    """
    try:
        count_filter = self._build_qdrant_filter(filters)
        tally = self._client.count(
            collection_name=self.name,
            count_filter=count_filter,
            exact=True,
        )
        return tally.count
    except Exception:
        return 0
|
596
|
-
|
597
|
-
def attach_to_database(self, database: Any) -> None:
    """Use ``database`` as this collection's storage backend.

    The backend is set first, then the database is asked to create a
    collection with this collection's name, schema and default TTL.
    """
    self._storage_backend = database

    # Make sure the database knows about this collection
    collection_options = {
        "schema": self.schema,
        "default_ttl": self.default_ttl,
    }
    database.create_collection(self.name, **collection_options)
|
@@ -1,21 +0,0 @@
|
|
1
|
-
"""hammad.data.databases"""
|
2
|
-
|
3
|
-
from typing import TYPE_CHECKING
|
4
|
-
from ..._core._utils._import_utils import _auto_create_getattr_loader
|
5
|
-
|
6
|
-
if TYPE_CHECKING:
|
7
|
-
from .database import Database, create_database
|
8
|
-
|
9
|
-
|
10
|
-
# Public names of this package; the TYPE_CHECKING import above makes them
# visible to static analysis only.
__all__ = ("Database", "create_database")


# Module-level `__getattr__` (PEP 562) produced by the project's loader
# helper. Presumably it resolves the names in `__all__` lazily on first
# attribute access - confirm against `_auto_create_getattr_loader`.
__getattr__ = _auto_create_getattr_loader(__all__)
|
17
|
-
|
18
|
-
|
19
|
-
def __dir__() -> list[str]:
    """Return the public names of the data.databases module as a new list."""
    return [*__all__]
|