brainlessdb 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,628 @@
1
+ """Collection class for managing entities."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import hashlib
6
+ import logging
7
+ import uuid as uuid_module
8
+ from collections.abc import AsyncIterator, Iterator
9
+ from dataclasses import asdict, is_dataclass
10
+ from typing import TYPE_CHECKING, Any, Generic, TypeVar
11
+
12
+ from brainless.entity import Entity, _make_tracked_instance
13
+ from brainless.schema import Schema, infer_schema_from_dict
14
+
15
+ if TYPE_CHECKING:
16
+ from brainless.bucket import Bucket
17
+ from brainless.client import Brainless
18
+
19
# Module-level logger, named after this module.
_log = logging.getLogger(__name__)

# Element type produced by a Collection (what _convert() returns).
T = TypeVar("T")
R = TypeVar("R")  # For typed() return
23
+
24
+
25
+ def _invert(value: Any) -> Any:
26
+ """Invert value for descending sort."""
27
+ if isinstance(value, (int, float)):
28
+ return -value
29
+ if isinstance(value, str):
30
+ # Invert each character's ordinal
31
+ return [-ord(c) for c in value]
32
+ # Fallback - wrap in tuple that sorts reversed
33
+ return value
34
+
35
+
36
class Collection(Generic[T]):
    """Collection of entities backed by NATS KV bucket

    Provides CRUD operations, dict-style access, iteration, and filtering.
    Schema is inferred from the first add() call, or from the first
    non-empty record seen by load().

    Can be typed for IDE support:
        queue: Collection[QueueItem] = brainless.queue.typed(QueueItem)
    """

    # Sentinel meaning "no expected value supplied" in nested lookups, so that
    # an explicit expected value of None can still be compared against.
    _UNSET: Any = object()

    def __init__(self, client: Brainless, name: str) -> None:
        """Create an empty, not-yet-loaded collection

        @param client: Owning client (provides location, connection, buckets)
        @param name: Collection name, also used as the bucket name
        """
        self._client = client
        self._name = name
        self._schema: Schema | None = None
        self._entities: dict[str, Entity] = {}
        # UUIDs waiting to be written / removed on the next flush()
        self._dirty: set[str] = set()
        self._deleted: set[str] = set()
        self._bucket: Bucket | None = None
        self._loaded = False
        self._cast_type: type | None = None
        # Lazy indexes: field -> {value -> {uuid, ...}}
        self._indexes: dict[str, dict[Any, set[str]]] = {}

    @property
    def name(self) -> str:
        """Collection name."""
        return self._name

    @property
    def schema(self) -> Schema | None:
        """Inferred schema, or None if nothing has been added or loaded yet."""
        return self._schema

    def typed(self, cls: type[R]) -> Collection[R]:
        """Set dataclass type for query results

        When set, find/filter/all/order_by/iteration return typed instances.

        @param cls: Dataclass type to convert results to
        @return: Self for chaining (typed as Collection[R] for IDE support)
        """
        self._cast_type = cls
        return self  # type: ignore[return-value]

    def _convert(self, entity: Entity) -> T:
        """Convert entity to tracked cast type if set."""
        if self._cast_type is None:
            return entity  # type: ignore[return-value]
        return _make_tracked_instance(self._cast_type, entity)  # type: ignore[return-value]

    def _generate_uuid(self) -> str:
        """Generate UUID1 with location-based node.

        The node is the first 6 bytes (48 bits) of the SHA-256 digest of the
        client's location string.
        """
        location_hash = hashlib.sha256(self._client.location.encode()).digest()
        node = int.from_bytes(location_hash[:6], "big")
        return str(uuid_module.uuid1(node=node))

    def _infer_schema(self, data: dict[str, Any]) -> None:
        """Infer and lock schema from first entity data."""
        self._schema = infer_schema_from_dict(self._name, data)
        self._schema.lock()
        _log.info(
            "Inferred schema for '%s': %s",
            self._name,
            list(self._schema.fields.keys()),
        )

    def mark_dirty(self, entity: Entity) -> None:
        """Mark entity for background flush."""
        self._dirty.add(entity.uuid)

    # ------------------------------------------------------------------
    # Index maintenance
    # ------------------------------------------------------------------

    @staticmethod
    def _is_hashable(value: Any) -> bool:
        """Return True if value can be used as an index (dict) key."""
        try:
            hash(value)
        except TypeError:
            return False
        return True

    @staticmethod
    def _index_put(index: dict[Any, set[str]], value: Any, uuid: str) -> None:
        """Record uuid under value; unhashable values are left unindexed."""
        if not Collection._is_hashable(value):
            return
        index.setdefault(value, set()).add(uuid)

    @staticmethod
    def _index_drop(index: dict[Any, set[str]], value: Any, uuid: str) -> None:
        """Forget uuid under value, pruning the slot once it is empty."""
        if not Collection._is_hashable(value):
            return
        members = index.get(value)
        if members is None:
            return
        members.discard(uuid)
        if not members:
            del index[value]

    def _build_index(self, field: str) -> dict[Any, set[str]]:
        """Build index for a field from all entities."""
        index: dict[Any, set[str]] = {}
        for uuid, entity in self._entities.items():
            if field in entity:
                self._index_put(index, entity[field], uuid)
        self._indexes[field] = index
        _log.debug("Built index for '%s.%s' (%d values)", self._name, field, len(index))
        return index

    def _get_index(self, field: str) -> dict[Any, set[str]]:
        """Get or build index for field."""
        if field not in self._indexes:
            return self._build_index(field)
        return self._indexes[field]

    def _index_add(self, entity: Entity) -> None:
        """Add entity to existing indexes."""
        for field, index in self._indexes.items():
            if field in entity:
                self._index_put(index, entity[field], entity.uuid)

    def _index_remove(self, entity: Entity) -> None:
        """Remove entity from all indexes."""
        for field, index in self._indexes.items():
            if field in entity:
                self._index_drop(index, entity[field], entity.uuid)

    def on_field_change(
        self,
        entity: Entity,
        field: str,
        old_value: Any,
        new_value: Any,
    ) -> None:
        """Called by Entity when a field value changes.

        Only fields that already have an index are tracked; indexes for other
        fields are built lazily on first lookup.
        """
        if field not in self._indexes:
            return

        index = self._indexes[field]
        self._index_drop(index, old_value, entity.uuid)
        self._index_put(index, new_value, entity.uuid)

    # ------------------------------------------------------------------
    # Persistence
    # ------------------------------------------------------------------

    async def _ensure_bucket(self) -> Bucket | None:
        """Get or create bucket for this collection.

        @return: Bucket, or None while the client is not connected
        """
        if self._bucket is None and self._client.connected:
            self._bucket = await self._client.get_bucket(self._name)
        return self._bucket

    async def _ensure_loaded(self) -> None:
        """Ensure collection is loaded from NATS."""
        if not self._loaded:
            await self.load()

    async def load(self) -> int:
        """Load all entities from NATS bucket

        Does nothing if the collection is already loaded. When no bucket is
        available the collection is marked loaded and stays as it is.

        @return: Number of entities loaded
        """
        if self._loaded:
            return 0

        bucket = await self._ensure_bucket()
        if bucket is None:
            self._loaded = True
            return 0

        data = await bucket.all()
        count = 0

        for uuid, entity_data in data.items():
            # Strip uuid from data if present (legacy data)
            entity_data.pop("uuid", None)

            # Infer schema from first record if not set
            if self._schema is None and entity_data:
                self._infer_schema(entity_data)

            self._entities[uuid] = Entity(self, uuid, entity_data)
            count += 1

        self._loaded = True
        if count > 0:
            _log.info("Loaded %d entities into '%s'", count, self._name)
        return count

    async def flush(self) -> int:
        """Flush dirty entities and deletions to NATS bucket

        @return: Number of operations performed
        """
        bucket = await self._ensure_bucket()
        ops = 0

        # Flush dirty entities. Iterate over a snapshot because the set is
        # modified while we go.
        for uuid in list(self._dirty):
            entity = self._entities.get(uuid)
            if entity is None:
                # Entity vanished since it was marked; nothing to write.
                self._dirty.discard(uuid)
                continue

            # NOTE(review): the dirty mark is cleared even when there is no
            # bucket (client not connected) - confirm this is intended.
            if bucket is not None:
                await bucket.put(uuid, entity.data)
            entity.mark_clean()
            self._dirty.discard(uuid)
            ops += 1

        # Flush deletions
        for uuid in list(self._deleted):
            if bucket is not None:
                await bucket.delete(uuid)
            self._deleted.discard(uuid)
            ops += 1

        if ops > 0:
            _log.debug("Flushed %d operations from '%s'", ops, self._name)

        return ops

    # ------------------------------------------------------------------
    # CRUD
    # ------------------------------------------------------------------

    def add(
        self,
        data: dict[str, Any] | Any | None = None,
        **kwargs: Any,
    ) -> T:
        """Add new entity to collection

        Accepts dict, dataclass instance, or keyword arguments.
        Schema is inferred from the first add() call.

        If a dataclass with a uuid field is passed, the generated uuid
        is set on the original object for convenience.

        @param data: Dictionary or dataclass instance
        @param kwargs: Field values as keyword arguments
        @return: Created entity (typed if typed() set)
        @raise TypeError: If data is neither a dict nor a dataclass instance
        @raise ValueError: If there are no fields besides uuid
        """
        input_dataclass = None

        # Normalize input to dict
        if data is None:
            entity_data = kwargs
        elif is_dataclass(data) and not isinstance(data, type):
            input_dataclass = data
            entity_data = {**asdict(data), **kwargs}
        elif isinstance(data, dict):
            entity_data = {**data, **kwargs}
        else:
            raise TypeError(f"Expected dict or dataclass, got {type(data).__name__}")

        # Remove uuid from data (stored as key, not in value). Done before the
        # emptiness check so a uuid-only payload is rejected as empty.
        entity_data.pop("uuid", None)

        if not entity_data:
            raise ValueError("Cannot add empty entity")

        # Generate uuid
        entity_uuid = self._generate_uuid()

        # Set uuid on input dataclass if it has a uuid field
        if input_dataclass is not None and hasattr(input_dataclass, "uuid"):
            input_dataclass.uuid = entity_uuid

        # Infer schema on first add (after uuid removed)
        if self._schema is None:
            self._infer_schema(entity_data)

        # Validate and apply defaults
        entity_data = self._schema.apply_defaults(entity_data)
        self._schema.validate(entity_data)

        # Create entity
        entity = Entity(self, entity_uuid, entity_data)
        self._entities[entity_uuid] = entity

        # Maintain indexes
        self._index_add(entity)

        # Mark for persistence
        entity.mark_dirty()

        return self._convert(entity)

    async def get(self, uuid: str) -> Entity | None:
        """Get entity by UUID

        @param uuid: Entity UUID
        @return: Entity or None if not found
        """
        await self._ensure_loaded()
        return self._entities.get(uuid)

    @staticmethod
    def _resolve_uuid(entity: str | Entity | Any) -> str | None:
        """Extract a uuid from a string or an object with a uuid attribute.

        @return: The uuid, or None if one cannot be derived from the argument
        """
        if isinstance(entity, str):
            return entity
        if hasattr(entity, "uuid"):
            return entity.uuid
        return None

    def delete(self, entity: str | Entity | Any) -> bool:
        """Delete entity by UUID, Entity, or object with uuid attribute

        Removal from NATS bucket happens on next flush.

        @param entity: Entity UUID string, Entity instance, or object with uuid attr
        @return: True if deleted, False if not found
        @raise TypeError: If no uuid can be derived from the argument
        """
        uuid = self._resolve_uuid(entity)
        if uuid is None:
            raise TypeError(f"Cannot get uuid from {type(entity).__name__}")

        existing = self._entities.get(uuid)
        if existing is None:
            return False

        # Remove from indexes
        self._index_remove(existing)

        del self._entities[uuid]
        self._dirty.discard(uuid)
        self._deleted.add(uuid)
        return True

    # ------------------------------------------------------------------
    # Querying
    # ------------------------------------------------------------------

    def _get_nested_value(self, entity: Entity, key: str, expected: Any = _UNSET) -> tuple[bool, Any]:
        """Get value from entity, supporting nested access via __

        When a list or dict is encountered, checks if ANY item matches.
        If expected is provided, returns True if any path equals expected.

        @param entity: Entity to get value from
        @param key: Field name, supports __ for nested access
        @param expected: If provided (including None), check if any path
            matches this value
        @return: (found, value) tuple
        """
        return self._traverse_path(entity, key.split("__"), expected)

    def _traverse_path(self, value: Any, parts: list[str], expected: Any = _UNSET) -> tuple[bool, Any]:
        """Recursively traverse a path through nested structures.

        Handles dicts, Entities, and lists/dicts (checks if ANY item matches).
        If expected is provided, checks if any terminal value equals expected.

        @param value: Current node being inspected
        @param parts: Path components still to be resolved
        @param expected: Value the terminal node must equal, if supplied
        @return: (found, value) tuple
        """
        if not parts:
            # Terminal - check expected if provided
            if expected is not self._UNSET:
                return value == expected, value
            return True, value

        part, remaining = parts[0], parts[1:]

        if isinstance(value, Entity):
            if part not in value:
                return False, None
            return self._traverse_path(value[part], remaining, expected)

        if isinstance(value, dict):
            if part in value:
                return self._traverse_path(value[part], remaining, expected)
            # part not in value - try iterating values (for dict[K, Dataclass] pattern)
            children: Any = value.values()
        elif isinstance(value, list):
            # For lists, check if ANY item matches the remaining path
            children = value
        else:
            return False, None

        # The same path (current part included) is retried on every child.
        for item in children:
            found, result = self._traverse_path(item, parts, expected)
            if found:
                return True, result
        return False, None

    def _matches(self, entity: Entity, criteria: dict[str, Any]) -> bool:
        """Check if entity matches all criteria."""
        for key, expected in criteria.items():
            found, _ = self._get_nested_value(entity, key, expected)
            if not found:
                return False
        return True

    def _iter_matching(self, criteria: dict[str, Any]) -> Iterator[Entity]:
        """Yield raw entities matching criteria, using indexes where possible.

        Top-level keys with hashable values are resolved through the lazy
        indexes; nested keys (containing __) and unhashable values are checked
        per entity. With no criteria every entity is yielded.
        """
        candidates: set[str] | None = None
        scan_criteria: dict[str, Any] = {}

        for key, value in criteria.items():
            if "__" in key or not self._is_hashable(value):
                scan_criteria[key] = value
                continue

            uuids = self._get_index(key).get(value, set())
            if candidates is None:
                # Copy so the intersection below never mutates the index.
                candidates = uuids.copy()
            else:
                candidates &= uuids
            # Early exit if no matches
            if not candidates:
                return

        if candidates is None:
            # No indexable fields, full scan
            pool: Any = self._entities.values()
        else:
            pool = (self._entities[u] for u in candidates if u in self._entities)

        for entity in pool:
            if self._matches(entity, scan_criteria):
                yield entity

    async def filter(self, **criteria: Any) -> list[T]:
        """Filter entities by field values

        Supports nested access via double underscore:
            filter(caller__city="Prague")

        Uses indexes for O(1) lookup on top-level fields.

        @param criteria: Field-value pairs to match
        @return: List of matching entities (typed if typed() set)
        """
        await self._ensure_loaded()
        return [self._convert(e) for e in self._iter_matching(criteria)]

    async def find(self, **criteria: Any) -> T | None:
        """Find first entity matching criteria

        Supports nested access via double underscore:
            find(caller__city="Prague")

        Uses indexes for O(1) lookup on top-level fields.

        @param criteria: Field-value pairs to match
        @return: First matching entity (typed if typed() set) or None
        """
        await self._ensure_loaded()
        for entity in self._iter_matching(criteria):
            return self._convert(entity)
        return None

    async def all(self) -> list[T]:
        """Get all entities in collection."""
        await self._ensure_loaded()
        return [self._convert(e) for e in self._entities.values()]

    async def order_by(self, *keys: str, **criteria: Any) -> list[T]:
        """Get entities sorted by key(s), optionally filtered

        Use minus prefix for descending order: order_by("-created_at")
        Multiple keys are applied in order: order_by("priority", "-created_at")

        @param keys: Field names to sort by (prefix with - for descending)
        @param criteria: Filter criteria (same as filter())
        @return: Sorted list of entities (typed if typed() set)
        """
        await self._ensure_loaded()

        # Get raw entities for sorting (filter without conversion)
        entities = list(self._iter_matching(criteria))

        if not keys:
            return [self._convert(e) for e in entities]

        def sort_key(entity: Entity) -> tuple:
            result = []
            for key in keys:
                descending = key.startswith("-")
                field = key[1:] if descending else key

                found, value = self._get_nested_value(entity, field)
                if not found:
                    value = None

                # None values sort last regardless of direction
                if value is None:
                    result.append((1, None))
                elif descending:
                    result.append((0, _invert(value)))
                else:
                    result.append((0, value))
            return tuple(result)

        return [self._convert(e) for e in sorted(entities, key=sort_key)]

    def count(self) -> int:
        """Return number of entities in collection."""
        return len(self._entities)

    def clear(self) -> None:
        """Remove all entities from collection (in-memory only).

        Entities, dirty marks and indexes are dropped; pending deletions and
        the loaded flag are left as they are.
        """
        self._entities.clear()
        self._dirty.clear()
        self._indexes.clear()

    # ------------------------------------------------------------------
    # Container protocol
    # ------------------------------------------------------------------

    def __getitem__(self, uuid: str) -> Entity:
        """Dict-style access by UUID."""
        entity = self._entities.get(uuid)
        if entity is None:
            raise KeyError(uuid)
        return entity

    def __delitem__(self, entity: str | Entity | Any) -> None:
        """Dict-style deletion.

        @raise KeyError: If the entity is not in the collection
        @raise TypeError: If no uuid can be derived from the argument
        """
        # delete() has already raised TypeError for unresolvable arguments,
        # so the uuid is always resolvable when we get here.
        if not self.delete(entity):
            raise KeyError(self._resolve_uuid(entity))

    def __contains__(self, entity: str | Entity | Any) -> bool:
        """Check if entity exists in collection."""
        uuid = self._resolve_uuid(entity)
        if uuid is None:
            return False
        return uuid in self._entities

    def __iter__(self) -> Iterator[T]:
        """Iterate over all entities (converted if typed() set)."""
        for entity in self._entities.values():
            yield self._convert(entity)

    async def __aiter__(self) -> AsyncIterator[T]:
        """Async iterate over all entities (loads if needed)."""
        await self._ensure_loaded()
        for entity in self._entities.values():
            yield self._convert(entity)

    def __len__(self) -> int:
        """Return number of entities."""
        return len(self._entities)

    def __repr__(self) -> str:
        return f"<Collection '{self._name}' ({len(self._entities)} entities)>"