mantisdk 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mantisdk might be problematic. Click here for more details.
- mantisdk/__init__.py +22 -0
- mantisdk/adapter/__init__.py +15 -0
- mantisdk/adapter/base.py +94 -0
- mantisdk/adapter/messages.py +270 -0
- mantisdk/adapter/triplet.py +1028 -0
- mantisdk/algorithm/__init__.py +39 -0
- mantisdk/algorithm/apo/__init__.py +5 -0
- mantisdk/algorithm/apo/apo.py +889 -0
- mantisdk/algorithm/apo/prompts/apply_edit_variant01.poml +22 -0
- mantisdk/algorithm/apo/prompts/apply_edit_variant02.poml +18 -0
- mantisdk/algorithm/apo/prompts/text_gradient_variant01.poml +18 -0
- mantisdk/algorithm/apo/prompts/text_gradient_variant02.poml +16 -0
- mantisdk/algorithm/apo/prompts/text_gradient_variant03.poml +107 -0
- mantisdk/algorithm/base.py +162 -0
- mantisdk/algorithm/decorator.py +264 -0
- mantisdk/algorithm/fast.py +250 -0
- mantisdk/algorithm/gepa/__init__.py +59 -0
- mantisdk/algorithm/gepa/adapter.py +459 -0
- mantisdk/algorithm/gepa/gepa.py +364 -0
- mantisdk/algorithm/gepa/lib/__init__.py +18 -0
- mantisdk/algorithm/gepa/lib/adapters/README.md +12 -0
- mantisdk/algorithm/gepa/lib/adapters/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/README.md +341 -0
- mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/__init__.py +1 -0
- mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/anymaths_adapter.py +174 -0
- mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/requirements.txt +1 -0
- mantisdk/algorithm/gepa/lib/adapters/default_adapter/README.md +0 -0
- mantisdk/algorithm/gepa/lib/adapters/default_adapter/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/adapters/default_adapter/default_adapter.py +209 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/README.md +7 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/dspy_adapter.py +307 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/README.md +99 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/dspy_program_proposal_signature.py +137 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/full_program_adapter.py +266 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/GEPA_RAG.md +621 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/__init__.py +56 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/evaluation_metrics.py +226 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/generic_rag_adapter.py +496 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/rag_pipeline.py +238 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_store_interface.py +212 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/__init__.py +2 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/chroma_store.py +196 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/lancedb_store.py +422 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/milvus_store.py +409 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/qdrant_store.py +368 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/weaviate_store.py +418 -0
- mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/README.md +552 -0
- mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/__init__.py +37 -0
- mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/mcp_adapter.py +705 -0
- mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/mcp_client.py +364 -0
- mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/README.md +9 -0
- mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/terminal_bench_adapter.py +217 -0
- mantisdk/algorithm/gepa/lib/api.py +375 -0
- mantisdk/algorithm/gepa/lib/core/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/core/adapter.py +180 -0
- mantisdk/algorithm/gepa/lib/core/data_loader.py +74 -0
- mantisdk/algorithm/gepa/lib/core/engine.py +356 -0
- mantisdk/algorithm/gepa/lib/core/result.py +233 -0
- mantisdk/algorithm/gepa/lib/core/state.py +636 -0
- mantisdk/algorithm/gepa/lib/examples/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/examples/aime.py +24 -0
- mantisdk/algorithm/gepa/lib/examples/anymaths-bench/eval_default.py +111 -0
- mantisdk/algorithm/gepa/lib/examples/anymaths-bench/prompt-templates/instruction_prompt.txt +9 -0
- mantisdk/algorithm/gepa/lib/examples/anymaths-bench/prompt-templates/optimal_prompt.txt +24 -0
- mantisdk/algorithm/gepa/lib/examples/anymaths-bench/train_anymaths.py +177 -0
- mantisdk/algorithm/gepa/lib/examples/dspy_full_program_evolution/arc_agi.ipynb +25705 -0
- mantisdk/algorithm/gepa/lib/examples/dspy_full_program_evolution/example.ipynb +348 -0
- mantisdk/algorithm/gepa/lib/examples/mcp_adapter/__init__.py +4 -0
- mantisdk/algorithm/gepa/lib/examples/mcp_adapter/mcp_optimization_example.py +455 -0
- mantisdk/algorithm/gepa/lib/examples/rag_adapter/RAG_GUIDE.md +613 -0
- mantisdk/algorithm/gepa/lib/examples/rag_adapter/__init__.py +9 -0
- mantisdk/algorithm/gepa/lib/examples/rag_adapter/rag_optimization.py +824 -0
- mantisdk/algorithm/gepa/lib/examples/rag_adapter/requirements-rag.txt +29 -0
- mantisdk/algorithm/gepa/lib/examples/terminal-bench/prompt-templates/instruction_prompt.txt +16 -0
- mantisdk/algorithm/gepa/lib/examples/terminal-bench/prompt-templates/terminus.txt +9 -0
- mantisdk/algorithm/gepa/lib/examples/terminal-bench/train_terminus.py +161 -0
- mantisdk/algorithm/gepa/lib/gepa_utils.py +117 -0
- mantisdk/algorithm/gepa/lib/logging/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/logging/experiment_tracker.py +187 -0
- mantisdk/algorithm/gepa/lib/logging/logger.py +75 -0
- mantisdk/algorithm/gepa/lib/logging/utils.py +103 -0
- mantisdk/algorithm/gepa/lib/proposer/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/proposer/base.py +31 -0
- mantisdk/algorithm/gepa/lib/proposer/merge.py +357 -0
- mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/base.py +49 -0
- mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/reflective_mutation.py +176 -0
- mantisdk/algorithm/gepa/lib/py.typed +0 -0
- mantisdk/algorithm/gepa/lib/strategies/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/strategies/batch_sampler.py +77 -0
- mantisdk/algorithm/gepa/lib/strategies/candidate_selector.py +50 -0
- mantisdk/algorithm/gepa/lib/strategies/component_selector.py +36 -0
- mantisdk/algorithm/gepa/lib/strategies/eval_policy.py +64 -0
- mantisdk/algorithm/gepa/lib/strategies/instruction_proposal.py +127 -0
- mantisdk/algorithm/gepa/lib/utils/__init__.py +10 -0
- mantisdk/algorithm/gepa/lib/utils/stop_condition.py +196 -0
- mantisdk/algorithm/gepa/tracing.py +105 -0
- mantisdk/algorithm/utils.py +177 -0
- mantisdk/algorithm/verl/__init__.py +5 -0
- mantisdk/algorithm/verl/interface.py +202 -0
- mantisdk/cli/__init__.py +56 -0
- mantisdk/cli/prometheus.py +115 -0
- mantisdk/cli/store.py +131 -0
- mantisdk/cli/vllm.py +29 -0
- mantisdk/client.py +408 -0
- mantisdk/config.py +348 -0
- mantisdk/emitter/__init__.py +43 -0
- mantisdk/emitter/annotation.py +370 -0
- mantisdk/emitter/exception.py +54 -0
- mantisdk/emitter/message.py +61 -0
- mantisdk/emitter/object.py +117 -0
- mantisdk/emitter/reward.py +320 -0
- mantisdk/env_var.py +156 -0
- mantisdk/execution/__init__.py +15 -0
- mantisdk/execution/base.py +64 -0
- mantisdk/execution/client_server.py +443 -0
- mantisdk/execution/events.py +69 -0
- mantisdk/execution/inter_process.py +16 -0
- mantisdk/execution/shared_memory.py +282 -0
- mantisdk/instrumentation/__init__.py +119 -0
- mantisdk/instrumentation/agentops.py +314 -0
- mantisdk/instrumentation/agentops_langchain.py +45 -0
- mantisdk/instrumentation/litellm.py +83 -0
- mantisdk/instrumentation/vllm.py +81 -0
- mantisdk/instrumentation/weave.py +500 -0
- mantisdk/litagent/__init__.py +11 -0
- mantisdk/litagent/decorator.py +536 -0
- mantisdk/litagent/litagent.py +252 -0
- mantisdk/llm_proxy.py +1890 -0
- mantisdk/logging.py +370 -0
- mantisdk/reward.py +7 -0
- mantisdk/runner/__init__.py +11 -0
- mantisdk/runner/agent.py +845 -0
- mantisdk/runner/base.py +182 -0
- mantisdk/runner/legacy.py +309 -0
- mantisdk/semconv.py +170 -0
- mantisdk/server.py +401 -0
- mantisdk/store/__init__.py +23 -0
- mantisdk/store/base.py +897 -0
- mantisdk/store/client_server.py +2092 -0
- mantisdk/store/collection/__init__.py +30 -0
- mantisdk/store/collection/base.py +587 -0
- mantisdk/store/collection/memory.py +970 -0
- mantisdk/store/collection/mongo.py +1412 -0
- mantisdk/store/collection_based.py +1823 -0
- mantisdk/store/insight.py +648 -0
- mantisdk/store/listener.py +58 -0
- mantisdk/store/memory.py +396 -0
- mantisdk/store/mongo.py +165 -0
- mantisdk/store/sqlite.py +3 -0
- mantisdk/store/threading.py +357 -0
- mantisdk/store/utils.py +142 -0
- mantisdk/tracer/__init__.py +16 -0
- mantisdk/tracer/agentops.py +242 -0
- mantisdk/tracer/base.py +287 -0
- mantisdk/tracer/dummy.py +106 -0
- mantisdk/tracer/otel.py +555 -0
- mantisdk/tracer/weave.py +677 -0
- mantisdk/trainer/__init__.py +6 -0
- mantisdk/trainer/init_utils.py +263 -0
- mantisdk/trainer/legacy.py +367 -0
- mantisdk/trainer/registry.py +12 -0
- mantisdk/trainer/trainer.py +618 -0
- mantisdk/types/__init__.py +6 -0
- mantisdk/types/core.py +553 -0
- mantisdk/types/resources.py +204 -0
- mantisdk/types/tracer.py +515 -0
- mantisdk/types/tracing.py +218 -0
- mantisdk/utils/__init__.py +1 -0
- mantisdk/utils/id.py +18 -0
- mantisdk/utils/metrics.py +1025 -0
- mantisdk/utils/otel.py +578 -0
- mantisdk/utils/otlp.py +536 -0
- mantisdk/utils/server_launcher.py +1045 -0
- mantisdk/utils/system_snapshot.py +81 -0
- mantisdk/verl/__init__.py +8 -0
- mantisdk/verl/__main__.py +6 -0
- mantisdk/verl/async_server.py +46 -0
- mantisdk/verl/config.yaml +27 -0
- mantisdk/verl/daemon.py +1154 -0
- mantisdk/verl/dataset.py +44 -0
- mantisdk/verl/entrypoint.py +248 -0
- mantisdk/verl/trainer.py +549 -0
- mantisdk-0.1.0.dist-info/METADATA +119 -0
- mantisdk-0.1.0.dist-info/RECORD +190 -0
- mantisdk-0.1.0.dist-info/WHEEL +4 -0
- mantisdk-0.1.0.dist-info/entry_points.txt +2 -0
- mantisdk-0.1.0.dist-info/licenses/LICENSE +19 -0
|
@@ -0,0 +1,970 @@
|
|
|
1
|
+
# Copyright (c) Microsoft. All rights reserved.
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import asyncio
|
|
6
|
+
import logging
|
|
7
|
+
import uuid
|
|
8
|
+
import weakref
|
|
9
|
+
from collections import deque
|
|
10
|
+
from contextlib import AsyncExitStack, asynccontextmanager
|
|
11
|
+
from typing import (
|
|
12
|
+
Any,
|
|
13
|
+
Deque,
|
|
14
|
+
Dict,
|
|
15
|
+
Iterable,
|
|
16
|
+
List,
|
|
17
|
+
Literal,
|
|
18
|
+
Mapping,
|
|
19
|
+
MutableMapping,
|
|
20
|
+
Optional,
|
|
21
|
+
Sequence,
|
|
22
|
+
Tuple,
|
|
23
|
+
Type,
|
|
24
|
+
TypeVar,
|
|
25
|
+
Union,
|
|
26
|
+
cast,
|
|
27
|
+
)
|
|
28
|
+
|
|
29
|
+
import aiologic
|
|
30
|
+
from pydantic import BaseModel
|
|
31
|
+
|
|
32
|
+
from mantisdk.types import (
|
|
33
|
+
Attempt,
|
|
34
|
+
FilterField,
|
|
35
|
+
FilterOptions,
|
|
36
|
+
PaginatedResult,
|
|
37
|
+
ResourcesUpdate,
|
|
38
|
+
Rollout,
|
|
39
|
+
SortOptions,
|
|
40
|
+
Span,
|
|
41
|
+
Worker,
|
|
42
|
+
)
|
|
43
|
+
from mantisdk.utils.metrics import MetricsBackend
|
|
44
|
+
|
|
45
|
+
from .base import (
|
|
46
|
+
AtomicLabels,
|
|
47
|
+
AtomicMode,
|
|
48
|
+
Collection,
|
|
49
|
+
DuplicatedPrimaryKeyError,
|
|
50
|
+
FilterMap,
|
|
51
|
+
KeyValue,
|
|
52
|
+
LightningCollections,
|
|
53
|
+
Queue,
|
|
54
|
+
ensure_numeric,
|
|
55
|
+
normalize_filter_options,
|
|
56
|
+
resolve_sort_options,
|
|
57
|
+
tracked,
|
|
58
|
+
)
|
|
59
|
+
|
|
60
|
+
# Type of the items stored in a collection.
# ListBasedCollection rejects dict (and dict-subclass) item types at construction.
T = TypeVar("T")  # Recommended to be a BaseModel, not a dict
# Generic key/value type variables. They are not referenced in the visible part of this
# module; presumably they serve key-value collection types (TODO confirm).
K = TypeVar("K")
V = TypeVar("V")

# Module-level logger (e.g. used to warn about contradictory exact filters).
logger = logging.getLogger(__name__)

# Nested structure type:
# dict[pk1] -> dict[pk2] -> ... -> item
# One dict level per primary key; the innermost dicts map the last key value to the item.
# NOTE(review): ListBasedCollection annotates its store as Dict[Any, Any] rather than
# using this alias.
ListBasedCollectionItemType = Union[
    Dict[Any, "ListBasedCollectionItemType[T]"],  # intermediate node
    Dict[Any, T],  # leaf node dictionary
]

# Operation names accepted by ListBasedCollection._mutate_single.
MutationMode = Literal["insert", "update", "upsert", "delete"]
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def _item_matches_filters(
    item: object,
    filters: Optional[FilterMap],
    filter_logic: Literal["and", "or"],
    must_filters: Optional[FilterMap] = None,
) -> bool:
    """Check whether an item matches the provided filter definition.

    Filter format:

    ```json
    {
        "_aggregate": "or",
        "field_name": {
            "exact": <value>,
            "within": <iterable_of_allowed_values>,
            "contains": <substring_or_element>,
        },
        ...
    }
    ```

    This function does not read the ``_aggregate`` key itself. The combination logic
    arrives as ``filter_logic`` (presumably extracted by ``normalize_filter_options``).

    Pooling rules:

    - Every operator of every field adds one condition to a single pool.
    - The pool is combined with ``filter_logic``: ``all`` for ``"and"``, ``any`` for ``"or"``.
    - ``must_filters`` are checked first and must all hold, whatever ``filter_logic`` is.

    Operator semantics, where ``item_value`` is ``getattr(item, field_name, None)``:

    - ``exact``: ``item_value == expected``.
    - ``within``: ``item_value in expected``. A ``TypeError`` (for example an unhashable
      value, or a non-container ``expected``) counts as no match.
    - ``contains``: a substring test when both sides are ``str``, otherwise
      ``expected in item_value``. A ``None`` field or a ``TypeError`` counts as no match.

    An operator whose expected value is ``None`` is a no-op. When no effective condition
    remains (``filters`` is empty, or every operator is a no-op), the item matches under
    either logic.

    Args:
        item: Object whose attributes are tested.
        filters: Optional mapping of field name to operator dict.
        filter_logic: How the pooled conditions of ``filters`` are combined.
        must_filters: Optional mapping that must match in full (always AND-combined).

    Returns:
        True if the item satisfies the filters, False otherwise.

    Raises:
        ValueError: If an operator other than ``exact``/``within``/``contains`` is used
            with a non-None value.
    """
    if must_filters and not _item_matches_filters(item, must_filters, "and"):
        return False

    if not filters:
        return True

    all_conditions_match: List[bool] = []

    for field_name, ops in filters.items():
        item_value = getattr(item, field_name, None)

        for op_name, expected in ops.items():
            # Ignore no-op filters
            if expected is None:
                continue

            if op_name == "exact":
                all_conditions_match.append(item_value == expected)

            elif op_name == "within":
                try:
                    all_conditions_match.append(item_value in expected)  # type: ignore[arg-type]
                except TypeError:
                    all_conditions_match.append(False)

            elif op_name == "contains":
                if item_value is None:
                    all_conditions_match.append(False)
                elif isinstance(item_value, str) and isinstance(expected, str):
                    all_conditions_match.append(expected in item_value)
                else:
                    # Fallback: treat as generic iterable containment.
                    try:
                        all_conditions_match.append(expected in item_value)  # type: ignore[arg-type]
                    except TypeError:
                        all_conditions_match.append(False)
            else:
                raise ValueError(f"Unsupported filter operator '{op_name}' for field '{field_name}'")

    if not all_conditions_match:
        # Every operator was a no-op, so the filter is effectively absent. Treat it like an
        # empty ``filters`` map; otherwise OR logic would reject every item because
        # ``any([])`` is False.
        return True

    return all(all_conditions_match) if filter_logic == "and" else any(all_conditions_match)
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
def _get_sort_value(item: object, sort_by: str) -> Any:
    """Get a sort key for the given item/field.

    - If the field name ends with ``_time``, values are treated as comparable timestamps,
      and a missing (None) value sorts last as ``float("inf")``.
    - For other fields, a None value is replaced by a default inferred from the Pydantic
      ``model_fields`` annotation: ``""`` for str/Literal fields, ``0`` for int fields and
      ``0.0`` for float fields. Non-None values are returned unchanged.

    The annotation is matched on whole identifiers, so a type whose name merely contains
    ``int``/``str`` (e.g. ``Point``, ``Instruction``) is not mistaken for a primitive.

    Args:
        item: The object being sorted.
        sort_by: Name of the attribute to sort by.

    Returns:
        The value to use as the sort key.

    Raises:
        ValueError: If the value is None and the field is not declared in the item class's
            ``model_fields``, or if its annotation is not str/Literal/int/float-like.
    """
    value = getattr(item, sort_by, None)

    if sort_by.endswith("_time"):
        # For *_time fields, push missing values to the end.
        return float("inf") if value is None else value

    if value is not None:
        return value

    import re  # only needed on this rare None-value path

    # Introspect model field type to choose a reasonable default for None.
    model_fields = getattr(item.__class__, "model_fields", {})
    if sort_by not in model_fields:
        raise ValueError(
            f"Failed to sort items by '{sort_by}': field does not exist " f"on {item.__class__.__name__}"
        )

    field_type_str = str(model_fields[sort_by].annotation)
    # Compare whole identifiers, not substrings. "int" in "Optional[Point]" or
    # "str" in "Instruction" would otherwise pick a wrong default.
    type_tokens = set(re.findall(r"\w+", field_type_str))
    if "str" in type_tokens or "Literal" in type_tokens:
        return ""
    if "int" in type_tokens:
        return 0
    if "float" in type_tokens:
        return 0.0
    raise ValueError(f"Failed to sort items by '{sort_by}': unsupported field type {field_type_str!r}")
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
class ListBasedCollection(Collection[T]):
|
|
176
|
+
"""In-memory implementation of Collection using a nested dict for O(1) primary-key lookup.
|
|
177
|
+
|
|
178
|
+
The internal structure is:
|
|
179
|
+
|
|
180
|
+
{
|
|
181
|
+
pk1_value: {
|
|
182
|
+
pk2_value: {
|
|
183
|
+
...
|
|
184
|
+
pkN_value: item
|
|
185
|
+
}
|
|
186
|
+
}
|
|
187
|
+
}
|
|
188
|
+
|
|
189
|
+
where the nesting depth equals the number of primary keys.
|
|
190
|
+
|
|
191
|
+
Sorting behavior:
|
|
192
|
+
|
|
193
|
+
1. If no sort_by is provided, the items are returned in the order of insertion.
|
|
194
|
+
2. If sort_by is provided, the items are sorted by the value of the sort_by field.
|
|
195
|
+
3. If the sort_by field is a timestamp, the null values are treated as infinity.
|
|
196
|
+
4. If the sort_by field is not a timestamp, the null values are treated as empty string
|
|
197
|
+
if the field is str-like, 0 if the field is int-like, 0.0 if the field is float-like.
|
|
198
|
+
"""
|
|
199
|
+
|
|
200
|
+
def __init__(
    self,
    items: List[T],
    item_type: Type[T],
    primary_keys: Sequence[str],
    id: Optional[str] = None,
    tracker: Optional[MetricsBackend] = None,
):
    """Create an in-memory collection, optionally seeded with ``items``.

    Args:
        items: Initial items, inserted one by one in the given order. A duplicate primary
            key raises ``DuplicatedPrimaryKeyError``.
        item_type: Class every stored item must be an instance of; ``dict`` and its
            subclasses are rejected.
        primary_keys: Non-empty, ordered field names forming the composite key. The first
            name selects the outermost dict level.
        id: Collection identifier; a random UUID4 string is used when omitted.
        tracker: Optional metrics backend, forwarded to the base class.

    Raises:
        ValueError: If ``primary_keys`` is empty.
        TypeError: If ``item_type`` is a ``dict`` subclass.
    """
    super().__init__(tracker=tracker)
    if not primary_keys:
        raise ValueError("primary_keys must be non-empty")

    self._id = str(uuid.uuid4()) if id is None else id
    self._items: Dict[Any, Any] = {}
    self._size: int = 0
    if issubclass(item_type, dict):
        raise TypeError(f"Expect item to be not a dict, got {item_type.__name__}")
    self._item_type: Type[T] = item_type
    self._primary_keys: Tuple[str, ...] = tuple(primary_keys)

    # Seed through the regular insert path so types and keys are validated and counted.
    seed_items = items if items else []
    for initial_item in seed_items:
        self._mutate_single(initial_item, mode="insert")
|
|
223
|
+
|
|
224
|
+
@property
def collection_name(self) -> str:
    """Identifier of this collection: the ``id`` given at construction, or a generated UUID4 string."""
    name = self._id
    return name
|
|
227
|
+
|
|
228
|
+
def primary_keys(self) -> Sequence[str]:
    """Return the primary-key field names, outermost nesting level first."""
    key_names: Sequence[str] = self._primary_keys
    return key_names
|
|
231
|
+
|
|
232
|
+
def item_type(self) -> Type[T]:
    """Return the class every stored item must be an instance of (recommended: a Pydantic model)."""
    stored_type: Type[T] = self._item_type
    return stored_type
|
|
235
|
+
|
|
236
|
+
async def size(self) -> int:
    """Return how many items the collection currently holds.

    The count is a counter maintained by the mutation logic, so no traversal happens here.
    """
    item_count = self._size
    return item_count
|
|
239
|
+
|
|
240
|
+
def __repr__(self) -> str:
    """Render as ``<ClassName[ItemTypeName] (size)>`` for debugging."""
    class_name = self.__class__.__name__
    item_type_name = self.item_type().__name__
    return f"<{class_name}[{item_type_name}] ({self._size})>"
|
|
242
|
+
|
|
243
|
+
# -------------------------------------------------------------------------
|
|
244
|
+
# Internal helpers
|
|
245
|
+
# -------------------------------------------------------------------------
|
|
246
|
+
|
|
247
|
+
def _ensure_item_type(self, item: T) -> None:
    """Raise ``TypeError`` unless ``item`` is an instance of the collection's item type.

    Subclass instances are accepted (``isinstance`` semantics).
    """
    expected_type = self._item_type
    if isinstance(item, expected_type):
        return
    raise TypeError(f"Expected item of type {expected_type.__name__}, got {type(item).__name__}")
|
|
251
|
+
|
|
252
|
+
def _extract_primary_key_values(self, item: T) -> Tuple[Any, ...]:
    """Collect the item's primary-key attribute values, in primary-key order.

    A present attribute whose value is ``None`` is returned as ``None``. Only a missing
    attribute is an error.

    Raises:
        ValueError: If the item lacks one of the primary-key attributes.
    """
    missing = object()
    collected: List[Any] = []
    for field_name in self._primary_keys:
        field_value = getattr(item, field_name, missing)
        if field_value is missing:
            raise ValueError(f"Item {item} does not have primary key field '{field_name}'")
        collected.append(field_value)
    return tuple(collected)
|
|
264
|
+
|
|
265
|
+
def _render_key_values(self, key_values: Sequence[Any]) -> str:
    """Format key values as ``name=repr(value)`` pairs joined by ``", "``, for error messages.

    Pairs are formed with ``zip``, so extra names or values beyond the shorter side are dropped.
    """
    pairs = [f"{name}={value!r}" for name, value in zip(self._primary_keys, key_values)]
    return ", ".join(pairs)
|
|
267
|
+
|
|
268
|
+
def _locate_node(
    self,
    key_values: Sequence[Any],
    create_missing: bool,
) -> Tuple[MutableMapping[Any, Any], Any]:
    """Walk the nested dicts to the mapping that holds (or will hold) the leaf entry.

    The leaf key itself is not checked for existence; only the intermediate levels are.

    Args:
        key_values: Primary-key values, outermost level first.
        create_missing: If True, missing intermediate dicts are created on the way down.

    Returns:
        ``(parent_mapping, final_key)``, where ``parent_mapping[final_key]`` is the
        item's slot.

    Raises:
        ValueError: If ``key_values`` is empty, or an intermediate node is not a dict
            (corrupted structure).
        KeyError: If an intermediate level is missing and ``create_missing`` is False.
    """
    if not key_values:
        raise ValueError("key_values must be non-empty")

    *intermediate_keys, leaf_key = key_values
    node: MutableMapping[Any, Any] = self._items
    for segment in intermediate_keys:
        if segment not in node:
            if not create_missing:
                raise KeyError(f"Path does not exist for given primary keys: {self._render_key_values(key_values)}")
            node[segment] = {}
        child = node[segment]
        if not isinstance(child, dict):
            raise ValueError(f"Internal structure corrupted: expected dict, got {type(child)!r}")  # type: ignore
        node = child  # type: ignore
    return node, leaf_key
|
|
308
|
+
|
|
309
|
+
def _mutate_single(self, item: T, mode: MutationMode, update_fields: Sequence[str] | None = None) -> Optional[T]:
    """Core mutation logic shared by insert, update, upsert, and delete.

    The item's primary-key values select the slot in the nested-dict store.

    Args:
        item: Instance of the collection's item type. For ``"delete"`` only its
            primary-key fields are used, to find the entry to remove.
        mode: ``"insert"`` (key must be new), ``"update"`` (key must exist),
            ``"upsert"`` (insert or update) or ``"delete"`` (key must exist).
        update_fields: Applies to ``"update"``, and to ``"upsert"`` of an existing key.
            Names the fields copied from ``item`` onto the stored item via
            ``model_copy(update=...)``, which requires a Pydantic ``BaseModel`` item type.
            ``None`` replaces the stored item with ``item`` entirely.

    Returns:
        The stored item after an insert, upsert or update; ``None`` after a delete.

    Raises:
        TypeError: If ``item`` is not of the collection's item type, or if
            ``update_fields`` is used with a non-``BaseModel`` item type.
        ValueError: If a primary-key field is missing, an update/delete targets a key
            that does not exist, the internal structure is corrupted, or ``mode`` is unknown.
        DuplicatedPrimaryKeyError: If an insert targets a key that already exists.
    """
    self._ensure_item_type(item)
    key_values = self._extract_primary_key_values(item)

    if mode in ("insert", "upsert"):
        # Intermediate dicts along the key path are created on demand.
        parent, final_key = self._locate_node(key_values, create_missing=True)
        exists = final_key in parent

        if mode == "insert":
            if exists:
                raise DuplicatedPrimaryKeyError(
                    f"Item already exists with primary key(s): {self._render_key_values(key_values)}"
                )
            parent[final_key] = item
            self._size += 1
        else:  # upsert
            if not exists:
                # New key: stored as-is, like an insert (update_fields is not consulted).
                self._size += 1
                parent[final_key] = item

            elif update_fields is None:
                # update_or_insert: update all fields
                parent[final_key] = item

            else:
                if not issubclass(self._item_type, BaseModel):
                    raise TypeError(
                        f"When using update_fields, the item type must be a Pydantic BaseModel, got {self._item_type.__name__}"
                    )

                # Try to fetch the existing item
                existing = parent[final_key]
                if not isinstance(existing, self._item_type):
                    raise ValueError(
                        f"Internal structure corrupted: expected {self._item_type.__name__}, got {type(existing)!r}"
                    )

                # Defensive: _ensure_item_type above has already validated `item`.
                if not isinstance(item, self._item_type):
                    raise TypeError(
                        f"When using update_fields, the item type must be a Pydantic BaseModel, got {type(item).__name__}"
                    )

                # Partial update: copy only the listed fields onto the stored item.
                parent[final_key] = parent[final_key].model_copy(
                    update={field: getattr(item, field) for field in update_fields}
                )

        return parent[final_key]

    elif mode in ("update", "delete"):
        # For update/delete we must not create missing paths.
        try:
            parent, final_key = self._locate_node(key_values, create_missing=False)
        except KeyError:
            raise ValueError(
                f"Item does not exist with primary key(s): {self._render_key_values(key_values)}"
            ) from None

        # _locate_node only checks intermediate levels, so the leaf is checked here.
        if final_key not in parent:
            raise ValueError(f"Item does not exist with primary key(s): {self._render_key_values(key_values)}")

        if mode == "update":
            if update_fields is None:
                # replace the entire item
                parent[final_key] = item
            else:
                if not issubclass(self._item_type, BaseModel):
                    raise TypeError(
                        f"When using update_fields, the item type must be a Pydantic BaseModel, got {self._item_type.__name__}"
                    )
                # Defensive: _ensure_item_type above has already validated `item`.
                if not isinstance(item, self._item_type):
                    raise TypeError(
                        f"When using update_fields, the item type must be a Pydantic BaseModel, got {type(item).__name__}"
                    )
                # Partial update: copy only the listed fields onto the stored item.
                parent[final_key] = parent[final_key].model_copy(
                    update={field: getattr(item, field) for field in update_fields}
                )
            return parent[final_key]
        else:  # delete
            # Only the leaf entry is removed; intermediate dicts along the path are left
            # in place, even if they are now empty.
            del parent[final_key]
            self._size -= 1
    else:
        raise ValueError(f"Unknown mutation mode: {mode}")
|
|
392
|
+
|
|
393
|
+
def _iter_items(
    self,
    root: Optional[Mapping[Any, Any]] = None,
    filters: Optional[FilterMap] = None,
    must_filters: Optional[FilterMap] = None,
    filter_logic: Literal["and", "or"] = "and",
) -> Iterable[T]:
    """Yield every item stored under ``root``, optionally applying filters.

    Args:
        root: Node of the nested-dict store to start from; defaults to the whole
            collection (``self._items``).
        filters: Optional per-field operator map, combined with ``filter_logic``.
        must_filters: Optional per-field operator map that must always hold.
        filter_logic: ``"and"`` or ``"or"``, applied to ``filters``.

    Yields:
        Matching items in depth-first order. The entries of each dict are visited in that
        dict's iteration (insertion) order.

    Raises:
        ValueError: If a node holds something other than a dict or an item. This is
            raised lazily, during iteration.
    """
    if root is None:
        root = self._items
    if not root:
        return

    item_type = self._item_type
    # With both filter maps empty/None every item matches, so skip the per-item check.
    check_filters = bool(filters) or bool(must_filters)

    # Recursive walk: its depth equals the number of primary keys, so it stays shallow.
    # It keeps sibling sub-trees in insertion order. Popping whole dicts off a LIFO stack
    # would visit each level's sub-trees in reverse, breaking the insertion-order
    # guarantee for unsorted queries.
    def walk(node: Mapping[Any, Any]) -> Iterable[T]:
        for value in node.values():
            # Leaf nodes contain items; intermediate nodes are dicts.
            if isinstance(value, item_type):
                if not check_filters or _item_matches_filters(value, filters, filter_logic, must_filters):
                    yield value
            elif isinstance(value, dict):
                yield from walk(value)  # type: ignore[arg-type]
            else:
                raise ValueError(
                    f"Internal structure corrupted: expected dict or {item_type.__name__}, "
                    f"got {type(value)!r}"
                )

    yield from walk(root)
|
|
420
|
+
|
|
421
|
+
def _iter_matching_items(
    self,
    filters: Optional[FilterMap],
    must_filters: Optional[FilterMap],
    filter_logic: Literal["and", "or"],
) -> Iterable[T]:
    """Return the items that satisfy the filters, narrowing the search by primary key when possible.

    Leading primary-key fields can be pinned by pure ``{"exact": value}`` constraints.
    These come from ``must_filters``, and from ``filters`` when ``filter_logic`` is AND.
    When a prefix is pinned, only the sub-tree under it is visited. When every primary
    key is pinned, at most one item is looked up directly. Otherwise all items are
    scanned.

    Args:
        filters: Optional per-field operator map, combined with ``filter_logic``.
        must_filters: Optional per-field operator map that must always hold.
        filter_logic: How the operators in ``filters`` are combined.

    Returns:
        An iterable of matching items. Scans return a lazy iterator; the direct-lookup
        and no-match cases return a tuple. Contradictory exact constraints on a primary
        key log a warning and return an empty tuple.
    """

    def scan_everything() -> Iterable[T]:
        return self._iter_items(filters=filters, must_filters=must_filters, filter_logic=filter_logic)

    # OR-combined filters cannot pin a key, and there are no mandatory filters to use instead.
    if filter_logic != "and" and must_filters is None:
        return scan_everything()

    # Filter maps whose exact constraints may pin primary-key values (mandatory ones first).
    possible_sources = (must_filters, filters if filter_logic == "and" else None)
    prefix_sources: List[FilterMap] = [source for source in possible_sources if source]

    pk_values_prefix: List[Any] = []
    for pk in self._primary_keys:
        constraints = [source.get(pk) for source in prefix_sources]  # type: ignore[union-attr]
        constraints = [ops for ops in constraints if ops]
        if not constraints:
            break  # this key is unconstrained, so the usable prefix ends here

        exact_value: Any | None = None
        usable = True
        for ops in constraints:
            candidate = ops.get("exact")
            # Only a lone, non-None "exact" operator pins the key to a single value.
            if set(ops.keys()) != {"exact"} or candidate is None:
                usable = False
                break
            if exact_value is not None and candidate != exact_value:
                # Contradictory exact filters mean no items can match.
                logger.warning(f"Contradictory exact filters for field '{pk}': {exact_value} != {candidate}")
                return ()
            exact_value = candidate

        if not usable or exact_value is None:
            break
        pk_values_prefix.append(exact_value)

    if not pk_values_prefix:
        return scan_everything()

    try:
        parent, final_key = self._locate_node(pk_values_prefix, create_missing=False)
        node = parent.get(final_key)
        if len(pk_values_prefix) == len(self._primary_keys):
            # All primary keys specified -> at most a single item.
            if isinstance(node, self._item_type) and _item_matches_filters(
                node,
                filters,
                filter_logic,
                must_filters,
            ):
                return (node,)
            return ()
        # Prefix of primary keys specified -> iterate only the subtree below that prefix.
        if isinstance(node, dict):
            return self._iter_items(
                node,  # type: ignore
                filters=filters,
                must_filters=must_filters,
                filter_logic=filter_logic,
            )
        return ()
    except KeyError:
        # No items exist for this primary-key prefix.
        return ()
|
|
505
|
+
|
|
506
|
+
@tracked("query")
async def query(
    self,
    filter: Optional[FilterOptions] = None,
    sort: Optional[SortOptions] = None,
    limit: int = -1,
    offset: int = 0,
) -> PaginatedResult[T]:
    """Return one page of items matching ``filter``, optionally ordered by ``sort``.

    Args:
        filter: Field-to-operator mapping, optionally carrying the ``_aggregate`` logic key.
        sort: Field and direction to order by; when no sort field resolves, iteration order is kept.
        limit: Maximum number of items in the page; ``-1`` means unlimited.
        offset: Number of *matching* items to skip before the page starts.

    Returns:
        A ``PaginatedResult`` whose ``total`` counts every matching item, not only the page.
    """
    filters, must_filters, filter_logic = normalize_filter_options(filter)
    sort_by, sort_order = resolve_sort_options(sort)
    candidates: Iterable[T] = self._iter_matching_items(filters, must_filters, filter_logic)

    if sort_by:
        # Ordering needs every match, so materialize and sort them all.
        # ``sorted`` is stable: equal sort values keep their iteration order.
        ordered: List[T] = sorted(
            candidates,
            key=lambda entry: _get_sort_value(entry, sort_by),
            reverse=sort_order == "desc",
        )
        stop = None if limit == -1 else offset + limit
        return PaginatedResult(
            items=ordered[offset:stop],
            limit=limit,
            offset=offset,
            total=len(ordered),
        )

    # Unsorted: stream the matches once. Keep only those inside the
    # [offset, offset + limit) window, but count all of them for ``total``.
    page: List[T] = []
    seen = 0
    for entry in candidates:
        seen += 1
        past_offset = seen > offset
        has_room = limit == -1 or len(page) < limit
        if past_offset and has_room:
            page.append(entry)

    return PaginatedResult(items=page, limit=limit, offset=offset, total=seen)
|
|
569
|
+
|
|
570
|
+
@tracked("get")
async def get(
    self,
    filter: Optional[FilterOptions] = None,
    sort: Optional[SortOptions] = None,
) -> Optional[T]:
    """Return a single item matching ``filter``, or ``None`` when nothing matches.

    Without a sort field, the first match in iteration order is returned. With a
    sort field, the matches are scanned once. The smallest sort value wins when the
    order is ``"asc"``, and the largest wins for any other order. On ties the
    earliest match is kept.
    """
    filters, must_filters, filter_logic = normalize_filter_options(filter)
    sort_by, sort_order = resolve_sort_options(sort)
    candidates: Iterable[T] = self._iter_matching_items(filters, must_filters, filter_logic)

    if not sort_by:
        return next(iter(candidates), None)

    ascending = sort_order == "asc"
    winner: Optional[T] = None
    winner_value: Any = None
    for entry in candidates:
        value = _get_sort_value(entry, sort_by)
        if winner is None:
            # The first match seeds the running best.
            winner, winner_value = entry, value
        elif (value < winner_value) if ascending else (value > winner_value):
            # Strict comparison keeps the earliest item among equal sort values.
            winner, winner_value = entry, value
    return winner
|
|
606
|
+
|
|
607
|
+
@tracked("insert")
async def insert(self, items: Sequence[T]) -> None:
    """Insert ``items`` as new entries.

    The whole payload is type-checked and scanned for repeated primary keys
    before anything is written.

    NOTE(review): conflicts with already-stored items are detected per item inside
    ``_mutate_single``. Earlier items in the payload may therefore already be
    inserted when such a conflict raises. Confirm whether that is acceptable.

    Raises:
        DuplicatedPrimaryKeyError: If the payload repeats a primary key, or an item
            with the same primary keys already exists.
    """
    staged: List[T] = []
    payload_keys: set[Tuple[Any, ...]] = set()
    for candidate in items:
        self._ensure_item_type(candidate)
        pk = self._extract_primary_key_values(candidate)
        if pk in payload_keys:
            # Two items in the same call share a key; reject before mutating anything.
            rendered = self._render_key_values(pk)
            raise DuplicatedPrimaryKeyError(f"Insert payload contains duplicated primary key(s): {rendered}")
        payload_keys.add(pk)
        staged.append(candidate)

    # Reached only when the payload is internally consistent.
    for candidate in staged:
        self._mutate_single(candidate, mode="insert")
|
|
628
|
+
|
|
629
|
+
@tracked("update")
async def update(self, items: Sequence[T], update_fields: Sequence[str] | None = None) -> Sequence[T]:
    """Update the stored counterparts of ``items`` and return the updated items.

    Args:
        items: Items whose primary keys identify the stored entries to update.
        update_fields: Forwarded unchanged to ``_mutate_single`` (presumably it
            restricts which fields are written; verify there).

    Returns:
        The items returned by ``_mutate_single``, in input order.

    Raises:
        ValueError: If any item with the given primary keys does not exist.
        RuntimeError: If ``_mutate_single`` unexpectedly returns ``None``.
    """

    def _apply(candidate):
        outcome = self._mutate_single(candidate, mode="update", update_fields=update_fields)
        if outcome is None:
            raise RuntimeError(f"_mutate_single returned None for item {candidate}. This should never happen.")
        return outcome

    # Items are processed one at a time, in order; the first failure stops the batch.
    return [_apply(candidate) for candidate in items]
|
|
643
|
+
|
|
644
|
+
@tracked("upsert")
async def upsert(self, items: Sequence[T], update_fields: Sequence[str] | None = None) -> Sequence[T]:
    """Insert each item whose primary keys are absent and update the rest.

    Args:
        items: Items to insert or update, processed in order.
        update_fields: Forwarded unchanged to ``_mutate_single``.

    Returns:
        The items returned by ``_mutate_single``, in input order.

    Raises:
        RuntimeError: If ``_mutate_single`` unexpectedly returns ``None``.
    """
    results: List[T] = []
    for candidate in items:
        outcome = self._mutate_single(candidate, mode="upsert", update_fields=update_fields)
        if outcome is not None:
            results.append(outcome)
            continue
        raise RuntimeError(f"_mutate_single returned None for item {candidate}. This should never happen.")
    return results
|
|
654
|
+
|
|
655
|
+
@tracked("delete")
async def delete(self, items: Sequence[T]) -> None:
    """Delete the given items.

    The payload is processed in two phases. First, every item is passed through
    ``_ensure_item_type`` (as ``insert`` does); only then are the deletions
    performed. A wrongly-typed item therefore aborts the call before anything is
    removed.

    Existence is still checked per item by ``_mutate_single`` in the second phase.
    A missing item can therefore abort the call after earlier items in the same
    payload were already deleted.

    Raises:
        ValueError: If any item with the given primary keys does not exist.
    """
    # Snapshot the payload so both phases see exactly the same items, even if a
    # one-shot iterable is passed despite the ``Sequence`` annotation.
    pending = list(items)

    # Phase 1: validate types before mutating anything.
    for item in pending:
        self._ensure_item_type(item)

    # Phase 2: perform the deletions. ``_mutate_single`` validates existence
    # and updates the collection size.
    for item in pending:
        self._mutate_single(item, mode="delete")
|
|
667
|
+
|
|
668
|
+
|
|
669
|
+
class DequeBasedQueue(Queue[T]):
    """Queue implementation backed by collections.deque.

    Provides O(1) amortized enqueue (append) and dequeue (popleft).
    """

    def __init__(
        self,
        item_type: Type[T],
        items: Optional[Sequence[T]] = None,
        id: Optional[str] = None,
        tracker: Optional[MetricsBackend] = None,
    ):
        """Create the queue.

        Args:
            item_type: Type that items passed to ``has``/``enqueue`` must be instances of.
            items: Optional initial contents, in FIFO order. NOTE(review): unlike
                ``enqueue``, these are not type-checked; confirm that is intended.
            id: Collection name; a random UUID4 string is used when omitted.
            tracker: Optional metrics backend forwarded to the base class.
        """
        super().__init__(tracker=tracker)
        self._items: Deque[T] = deque()
        self._item_type: Type[T] = item_type
        self._id = id if id is not None else str(uuid.uuid4())
        if items:
            self._items.extend(items)

    def item_type(self) -> Type[T]:
        """Return the type that queued items are checked against."""
        return self._item_type

    @property
    def collection_name(self) -> str:
        """Name of this queue: the ``id`` given at construction, or the generated UUID."""
        return self._id

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__}[{self.item_type().__name__}] ({len(self._items)})>"

    @tracked("has")
    async def has(self, item: T) -> bool:
        """Return whether ``item`` is currently queued (linear membership scan).

        Raises:
            TypeError: If ``item`` is not an instance of the queue's item type.
        """
        if not isinstance(item, self._item_type):
            raise TypeError(f"Expected item of type {self._item_type.__name__}, got {type(item).__name__}")
        return item in self._items

    @tracked("enqueue")
    async def enqueue(self, items: Sequence[T]) -> Sequence[T]:
        """Append ``items`` to the tail of the queue and return ``items``.

        Every item is type-checked before any is appended. A payload containing a
        wrongly-typed element therefore leaves the queue unchanged instead of
        partially enqueued.

        Raises:
            TypeError: If any item is not an instance of the queue's item type.
        """
        # Snapshot once so validation and insertion see the same elements even if
        # a one-shot iterable is passed.
        pending = list(items)
        for item in pending:
            if not isinstance(item, self._item_type):
                raise TypeError(f"Expected item of type {self._item_type.__name__}, got {type(item).__name__}")
        self._items.extend(pending)
        return items

    @tracked("dequeue")
    async def dequeue(self, limit: int = 1) -> Sequence[T]:
        """Remove and return up to ``limit`` items from the head; ``[]`` when ``limit <= 0``."""
        if limit <= 0:
            return []
        out: List[T] = []
        for _ in range(min(limit, len(self._items))):
            out.append(self._items.popleft())
        return out

    @tracked("peek")
    async def peek(self, limit: int = 1) -> Sequence[T]:
        """Return up to ``limit`` items from the head without removing them; ``[]`` when ``limit <= 0``."""
        if limit <= 0:
            return []
        result: List[T] = []
        count = min(limit, len(self._items))
        # Iterate rather than index: deque indexing is O(n) away from the ends.
        for idx, item in enumerate(self._items):
            if idx >= count:
                break
            result.append(item)
        return result

    @tracked("size")
    async def size(self) -> int:
        """Return the number of queued items."""
        return len(self._items)
|
|
737
|
+
|
|
738
|
+
|
|
739
|
+
class DictBasedKeyValue(KeyValue[K, V]):
    """KeyValue implementation backed by a plain dictionary."""

    def __init__(
        self, data: Optional[Mapping[K, V]] = None, id: Optional[str] = None, tracker: Optional[MetricsBackend] = None
    ):
        super().__init__(tracker=tracker)
        # Copy ``data`` so later changes to the caller's mapping do not leak in.
        self._values: Dict[K, V] = dict(data) if data else {}
        self._id = id if id is not None else str(uuid.uuid4())

    @property
    def collection_name(self) -> str:
        """Name of this store: the ``id`` given at construction, or the generated UUID."""
        return self._id

    @tracked("has")
    async def has(self, key: K) -> bool:
        """Return whether ``key`` is present."""
        return key in self._values

    @tracked("get")
    async def get(self, key: K, default: V | None = None) -> V | None:
        """Return the value at ``key``, or ``default`` when the key is absent."""
        return self._values.get(key, default)

    @tracked("set")
    async def set(self, key: K, value: V) -> None:
        """Store ``value`` at ``key``, overwriting any previous value."""
        self._values[key] = value

    @tracked("inc")
    async def inc(self, key: K, amount: V) -> V:
        """Add ``amount`` to the value at ``key`` and return the new value.

        A missing key is initialised to ``amount``.

        Raises:
            Whatever ``ensure_numeric`` raises, or ``TypeError`` if it returns a falsy
            result for ``amount`` or for the stored value.
        """
        # Explicit checks instead of ``assert`` so validation survives ``python -O``.
        if not ensure_numeric(amount, description="amount"):
            raise TypeError(f"amount must be numeric, got {type(amount).__name__}")
        if key in self._values:
            current_value = self._values[key]
            if not ensure_numeric(current_value, description=f"value for key {key!r}"):
                raise TypeError(f"value for key {key!r} must be numeric, got {type(current_value).__name__}")
            new_value = cast(V, current_value + amount)
        else:
            new_value = amount
        self._values[key] = new_value
        return new_value

    @tracked("chmax")
    async def chmax(self, key: K, value: V) -> V:
        """Store the larger of ``value`` and the current value at ``key``, and return it.

        A missing key is simply set to ``value``.

        Raises:
            Whatever ``ensure_numeric`` raises, or ``TypeError`` if it returns a falsy
            result for ``value`` or for the stored value.
        """
        # Explicit checks instead of ``assert`` so validation survives ``python -O``.
        if not ensure_numeric(value, description="value"):
            raise TypeError(f"value must be numeric, got {type(value).__name__}")
        if key not in self._values:
            self._values[key] = value
            return value
        current_value = self._values[key]
        if not ensure_numeric(current_value, description=f"value for key {key!r}"):
            raise TypeError(f"value for key {key!r} must be numeric, got {type(current_value).__name__}")
        if value > current_value:
            self._values[key] = value
            return value
        return current_value

    @tracked("pop")
    async def pop(self, key: K, default: V | None = None) -> V | None:
        """Remove ``key`` and return its value, or ``default`` when absent."""
        return self._values.pop(key, default)

    @tracked("size")
    async def size(self) -> int:
        """Return the number of stored keys."""
        return len(self._values)
|
|
799
|
+
|
|
800
|
+
|
|
801
|
+
class InMemoryLightningCollections(LightningCollections):
    """In-memory implementation of LightningCollections using Python data structures.

    Serves as the storage base for [`InMemoryLightningStore`][mantisdk.InMemoryLightningStore].
    """

    def __init__(self, lock_type: Literal["thread", "asyncio"], tracker: MetricsBackend | None = None):
        """Create the empty collections and one lock per atomic label.

        Args:
            lock_type: ``"asyncio"`` selects ``_LoopAwareAsyncLock`` (one asyncio lock
                per event loop); any other value selects ``_ThreadSafeAsyncLock``.
            tracker: Optional metrics backend shared by every collection.
        """
        super().__init__(tracker=tracker)

        def new_lock():
            # Every label gets its own independent lock of the requested flavour.
            return _LoopAwareAsyncLock() if lock_type == "asyncio" else _ThreadSafeAsyncLock()

        lock_labels = (
            "rollouts",
            "attempts",
            "spans",
            "resources",
            "workers",
            "rollout_queue",
            "span_sequence_ids",
            "generic",
        )
        self._lock: Mapping[AtomicLabels, _LoopAwareAsyncLock | _ThreadSafeAsyncLock] = {
            label: new_lock() for label in lock_labels
        }
        self._rollouts = ListBasedCollection(
            items=[], item_type=Rollout, primary_keys=["rollout_id"], id="rollouts", tracker=tracker
        )
        self._attempts = ListBasedCollection(
            items=[], item_type=Attempt, primary_keys=["rollout_id", "attempt_id"], id="attempts", tracker=tracker
        )
        self._spans = ListBasedCollection(
            items=[], item_type=Span, primary_keys=["rollout_id", "attempt_id", "span_id"], id="spans", tracker=tracker
        )
        self._resources = ListBasedCollection(
            items=[], item_type=ResourcesUpdate, primary_keys=["resources_id"], id="resources", tracker=tracker
        )
        self._workers = ListBasedCollection(
            items=[], item_type=Worker, primary_keys=["worker_id"], id="workers", tracker=tracker
        )
        self._rollout_queue = DequeBasedQueue(items=[], item_type=str, id="rollout_queue", tracker=tracker)
        self._span_sequence_ids = DictBasedKeyValue[str, int](
            data={}, id="span_sequence_ids", tracker=tracker
        )  # rollout_id -> sequence_id

    @property
    def collection_name(self) -> str:
        return "router"

    @property
    def rollouts(self) -> ListBasedCollection[Rollout]:
        return self._rollouts

    @property
    def attempts(self) -> ListBasedCollection[Attempt]:
        return self._attempts

    @property
    def spans(self) -> ListBasedCollection[Span]:
        return self._spans

    @property
    def resources(self) -> ListBasedCollection[ResourcesUpdate]:
        return self._resources

    @property
    def workers(self) -> ListBasedCollection[Worker]:
        return self._workers

    @property
    def rollout_queue(self) -> DequeBasedQueue[str]:
        return self._rollout_queue

    @property
    def span_sequence_ids(self) -> DictBasedKeyValue[str, int]:
        return self._span_sequence_ids

    @asynccontextmanager
    async def atomic(
        self,
        *,
        mode: AtomicMode = "rw",
        snapshot: bool = False,
        labels: Optional[Sequence[AtomicLabels]] = None,
        **kwargs: Any,
    ):
        """Hold the locks for ``labels`` (all locks when omitted) for the duration of the block.

        The in-memory collections are protected only by the locks acquired here;
        this method does not touch the collections themselves.

        Locking is skipped entirely when ``mode == "r"`` and ``snapshot`` is False.

        This collection implementation does NOT support rollback / commit.
        """
        if mode == "r" and not snapshot:
            yield self
            return
        if not labels:
            # If no labels are provided, use all locks.
            labels = list(self._lock.keys())

        # IMPORTANT: Sort the labels to ensure consistent locking order.
        # This is necessary to avoid deadlocks when multiple threads/coroutines
        # are trying to acquire the same locks in different orders.
        labels = sorted(labels)

        async with self.tracking_context(operation="atomic", collection=self.collection_name):
            managers = [(label, self._lock[label]) for label in labels]
            async with AsyncExitStack() as stack:
                for label, manager in managers:
                    # Track only the acquisition; the lock stays held via ``stack``.
                    async with self.tracking_context(operation="lock", collection=label):
                        await stack.enter_async_context(manager)
                yield self

    @tracked("evict_spans_for_rollout")
    async def evict_spans_for_rollout(self, rollout_id: str) -> None:
        """Evict all spans for a given rollout ID.

        Uses private API for efficiency: the rollout's subtree is popped directly
        from the spans collection's internal ``_items`` mapping.

        NOTE(review): this bypasses ``_mutate_single``, so any size bookkeeping the
        spans collection maintains is presumably not updated here. Confirm.
        """
        self._spans._items.pop(rollout_id, [])  # pyright: ignore[reportPrivateUsage]
|
|
913
|
+
|
|
914
|
+
|
|
915
|
+
class _LoopAwareAsyncLock:
|
|
916
|
+
"""Async lock that transparently rebinds to the current event loop.
|
|
917
|
+
|
|
918
|
+
The lock intentionally remains *thread-unsafe*: callers must only use it from
|
|
919
|
+
one thread at a time. If multiple threads interact with the store, each
|
|
920
|
+
thread gets its own event loop specific lock.
|
|
921
|
+
"""
|
|
922
|
+
|
|
923
|
+
def __init__(self) -> None:
|
|
924
|
+
self._locks: weakref.WeakKeyDictionary[asyncio.AbstractEventLoop, asyncio.Lock] = weakref.WeakKeyDictionary()
|
|
925
|
+
|
|
926
|
+
# When serializing and deserializing, we don't need to serialize the locks.
|
|
927
|
+
# Because another process will have its own set of event loops and its own lock.
|
|
928
|
+
def __getstate__(self) -> dict[str, Any]:
|
|
929
|
+
return {}
|
|
930
|
+
|
|
931
|
+
def __setstate__(self, state: dict[str, Any]) -> None:
|
|
932
|
+
self._locks = weakref.WeakKeyDictionary()
|
|
933
|
+
|
|
934
|
+
def _get_lock_for_current_loop(self) -> asyncio.Lock:
|
|
935
|
+
loop = asyncio.get_running_loop()
|
|
936
|
+
lock = self._locks.get(loop)
|
|
937
|
+
if lock is None:
|
|
938
|
+
lock = asyncio.Lock()
|
|
939
|
+
self._locks[loop] = lock
|
|
940
|
+
return lock
|
|
941
|
+
|
|
942
|
+
async def __aenter__(self) -> asyncio.Lock:
|
|
943
|
+
lock = self._get_lock_for_current_loop()
|
|
944
|
+
await lock.acquire()
|
|
945
|
+
return lock
|
|
946
|
+
|
|
947
|
+
async def __aexit__(self, exc_type: type[BaseException] | None, exc: BaseException | None, tb: Any) -> None:
|
|
948
|
+
loop = asyncio.get_running_loop()
|
|
949
|
+
lock = self._locks.get(loop)
|
|
950
|
+
if lock is None or not lock.locked():
|
|
951
|
+
raise RuntimeError("Lock released without being acquired")
|
|
952
|
+
lock.release()
|
|
953
|
+
|
|
954
|
+
|
|
955
|
+
class _ThreadSafeAsyncLock:
|
|
956
|
+
"""A thread lock powered by aiologic that can be used in both async and sync contexts.
|
|
957
|
+
|
|
958
|
+
aiologic claims itself to be a thread-safe asyncio lock.
|
|
959
|
+
"""
|
|
960
|
+
|
|
961
|
+
def __init__(self):
    """Create the single aiologic lock that backs this wrapper."""
    backing_lock = aiologic.Lock()
    self._lock = backing_lock
|
|
963
|
+
|
|
964
|
+
async def __aenter__(self):
    """Wait for the aiologic lock from async code and return this wrapper."""
    backing_lock = self._lock
    await backing_lock.async_acquire()
    return self
|
|
967
|
+
|
|
968
|
+
async def __aexit__(self, *args: Any, **kwargs: Any):
    """Release the aiologic lock; exceptions from the body are never suppressed."""
    # aiologic's ``async_release()`` does not block, so it is called directly
    # rather than awaited (verify against the aiologic docs when upgrading).
    release = self._lock.async_release
    release()
|