mantisdk 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mantisdk might be problematic. Click here for more details.
- mantisdk/__init__.py +22 -0
- mantisdk/adapter/__init__.py +15 -0
- mantisdk/adapter/base.py +94 -0
- mantisdk/adapter/messages.py +270 -0
- mantisdk/adapter/triplet.py +1028 -0
- mantisdk/algorithm/__init__.py +39 -0
- mantisdk/algorithm/apo/__init__.py +5 -0
- mantisdk/algorithm/apo/apo.py +889 -0
- mantisdk/algorithm/apo/prompts/apply_edit_variant01.poml +22 -0
- mantisdk/algorithm/apo/prompts/apply_edit_variant02.poml +18 -0
- mantisdk/algorithm/apo/prompts/text_gradient_variant01.poml +18 -0
- mantisdk/algorithm/apo/prompts/text_gradient_variant02.poml +16 -0
- mantisdk/algorithm/apo/prompts/text_gradient_variant03.poml +107 -0
- mantisdk/algorithm/base.py +162 -0
- mantisdk/algorithm/decorator.py +264 -0
- mantisdk/algorithm/fast.py +250 -0
- mantisdk/algorithm/gepa/__init__.py +59 -0
- mantisdk/algorithm/gepa/adapter.py +459 -0
- mantisdk/algorithm/gepa/gepa.py +364 -0
- mantisdk/algorithm/gepa/lib/__init__.py +18 -0
- mantisdk/algorithm/gepa/lib/adapters/README.md +12 -0
- mantisdk/algorithm/gepa/lib/adapters/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/README.md +341 -0
- mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/__init__.py +1 -0
- mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/anymaths_adapter.py +174 -0
- mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/requirements.txt +1 -0
- mantisdk/algorithm/gepa/lib/adapters/default_adapter/README.md +0 -0
- mantisdk/algorithm/gepa/lib/adapters/default_adapter/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/adapters/default_adapter/default_adapter.py +209 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/README.md +7 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/dspy_adapter.py +307 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/README.md +99 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/dspy_program_proposal_signature.py +137 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/full_program_adapter.py +266 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/GEPA_RAG.md +621 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/__init__.py +56 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/evaluation_metrics.py +226 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/generic_rag_adapter.py +496 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/rag_pipeline.py +238 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_store_interface.py +212 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/__init__.py +2 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/chroma_store.py +196 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/lancedb_store.py +422 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/milvus_store.py +409 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/qdrant_store.py +368 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/weaviate_store.py +418 -0
- mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/README.md +552 -0
- mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/__init__.py +37 -0
- mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/mcp_adapter.py +705 -0
- mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/mcp_client.py +364 -0
- mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/README.md +9 -0
- mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/terminal_bench_adapter.py +217 -0
- mantisdk/algorithm/gepa/lib/api.py +375 -0
- mantisdk/algorithm/gepa/lib/core/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/core/adapter.py +180 -0
- mantisdk/algorithm/gepa/lib/core/data_loader.py +74 -0
- mantisdk/algorithm/gepa/lib/core/engine.py +356 -0
- mantisdk/algorithm/gepa/lib/core/result.py +233 -0
- mantisdk/algorithm/gepa/lib/core/state.py +636 -0
- mantisdk/algorithm/gepa/lib/examples/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/examples/aime.py +24 -0
- mantisdk/algorithm/gepa/lib/examples/anymaths-bench/eval_default.py +111 -0
- mantisdk/algorithm/gepa/lib/examples/anymaths-bench/prompt-templates/instruction_prompt.txt +9 -0
- mantisdk/algorithm/gepa/lib/examples/anymaths-bench/prompt-templates/optimal_prompt.txt +24 -0
- mantisdk/algorithm/gepa/lib/examples/anymaths-bench/train_anymaths.py +177 -0
- mantisdk/algorithm/gepa/lib/examples/dspy_full_program_evolution/arc_agi.ipynb +25705 -0
- mantisdk/algorithm/gepa/lib/examples/dspy_full_program_evolution/example.ipynb +348 -0
- mantisdk/algorithm/gepa/lib/examples/mcp_adapter/__init__.py +4 -0
- mantisdk/algorithm/gepa/lib/examples/mcp_adapter/mcp_optimization_example.py +455 -0
- mantisdk/algorithm/gepa/lib/examples/rag_adapter/RAG_GUIDE.md +613 -0
- mantisdk/algorithm/gepa/lib/examples/rag_adapter/__init__.py +9 -0
- mantisdk/algorithm/gepa/lib/examples/rag_adapter/rag_optimization.py +824 -0
- mantisdk/algorithm/gepa/lib/examples/rag_adapter/requirements-rag.txt +29 -0
- mantisdk/algorithm/gepa/lib/examples/terminal-bench/prompt-templates/instruction_prompt.txt +16 -0
- mantisdk/algorithm/gepa/lib/examples/terminal-bench/prompt-templates/terminus.txt +9 -0
- mantisdk/algorithm/gepa/lib/examples/terminal-bench/train_terminus.py +161 -0
- mantisdk/algorithm/gepa/lib/gepa_utils.py +117 -0
- mantisdk/algorithm/gepa/lib/logging/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/logging/experiment_tracker.py +187 -0
- mantisdk/algorithm/gepa/lib/logging/logger.py +75 -0
- mantisdk/algorithm/gepa/lib/logging/utils.py +103 -0
- mantisdk/algorithm/gepa/lib/proposer/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/proposer/base.py +31 -0
- mantisdk/algorithm/gepa/lib/proposer/merge.py +357 -0
- mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/base.py +49 -0
- mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/reflective_mutation.py +176 -0
- mantisdk/algorithm/gepa/lib/py.typed +0 -0
- mantisdk/algorithm/gepa/lib/strategies/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/strategies/batch_sampler.py +77 -0
- mantisdk/algorithm/gepa/lib/strategies/candidate_selector.py +50 -0
- mantisdk/algorithm/gepa/lib/strategies/component_selector.py +36 -0
- mantisdk/algorithm/gepa/lib/strategies/eval_policy.py +64 -0
- mantisdk/algorithm/gepa/lib/strategies/instruction_proposal.py +127 -0
- mantisdk/algorithm/gepa/lib/utils/__init__.py +10 -0
- mantisdk/algorithm/gepa/lib/utils/stop_condition.py +196 -0
- mantisdk/algorithm/gepa/tracing.py +105 -0
- mantisdk/algorithm/utils.py +177 -0
- mantisdk/algorithm/verl/__init__.py +5 -0
- mantisdk/algorithm/verl/interface.py +202 -0
- mantisdk/cli/__init__.py +56 -0
- mantisdk/cli/prometheus.py +115 -0
- mantisdk/cli/store.py +131 -0
- mantisdk/cli/vllm.py +29 -0
- mantisdk/client.py +408 -0
- mantisdk/config.py +348 -0
- mantisdk/emitter/__init__.py +43 -0
- mantisdk/emitter/annotation.py +370 -0
- mantisdk/emitter/exception.py +54 -0
- mantisdk/emitter/message.py +61 -0
- mantisdk/emitter/object.py +117 -0
- mantisdk/emitter/reward.py +320 -0
- mantisdk/env_var.py +156 -0
- mantisdk/execution/__init__.py +15 -0
- mantisdk/execution/base.py +64 -0
- mantisdk/execution/client_server.py +443 -0
- mantisdk/execution/events.py +69 -0
- mantisdk/execution/inter_process.py +16 -0
- mantisdk/execution/shared_memory.py +282 -0
- mantisdk/instrumentation/__init__.py +119 -0
- mantisdk/instrumentation/agentops.py +314 -0
- mantisdk/instrumentation/agentops_langchain.py +45 -0
- mantisdk/instrumentation/litellm.py +83 -0
- mantisdk/instrumentation/vllm.py +81 -0
- mantisdk/instrumentation/weave.py +500 -0
- mantisdk/litagent/__init__.py +11 -0
- mantisdk/litagent/decorator.py +536 -0
- mantisdk/litagent/litagent.py +252 -0
- mantisdk/llm_proxy.py +1890 -0
- mantisdk/logging.py +370 -0
- mantisdk/reward.py +7 -0
- mantisdk/runner/__init__.py +11 -0
- mantisdk/runner/agent.py +845 -0
- mantisdk/runner/base.py +182 -0
- mantisdk/runner/legacy.py +309 -0
- mantisdk/semconv.py +170 -0
- mantisdk/server.py +401 -0
- mantisdk/store/__init__.py +23 -0
- mantisdk/store/base.py +897 -0
- mantisdk/store/client_server.py +2092 -0
- mantisdk/store/collection/__init__.py +30 -0
- mantisdk/store/collection/base.py +587 -0
- mantisdk/store/collection/memory.py +970 -0
- mantisdk/store/collection/mongo.py +1412 -0
- mantisdk/store/collection_based.py +1823 -0
- mantisdk/store/insight.py +648 -0
- mantisdk/store/listener.py +58 -0
- mantisdk/store/memory.py +396 -0
- mantisdk/store/mongo.py +165 -0
- mantisdk/store/sqlite.py +3 -0
- mantisdk/store/threading.py +357 -0
- mantisdk/store/utils.py +142 -0
- mantisdk/tracer/__init__.py +16 -0
- mantisdk/tracer/agentops.py +242 -0
- mantisdk/tracer/base.py +287 -0
- mantisdk/tracer/dummy.py +106 -0
- mantisdk/tracer/otel.py +555 -0
- mantisdk/tracer/weave.py +677 -0
- mantisdk/trainer/__init__.py +6 -0
- mantisdk/trainer/init_utils.py +263 -0
- mantisdk/trainer/legacy.py +367 -0
- mantisdk/trainer/registry.py +12 -0
- mantisdk/trainer/trainer.py +618 -0
- mantisdk/types/__init__.py +6 -0
- mantisdk/types/core.py +553 -0
- mantisdk/types/resources.py +204 -0
- mantisdk/types/tracer.py +515 -0
- mantisdk/types/tracing.py +218 -0
- mantisdk/utils/__init__.py +1 -0
- mantisdk/utils/id.py +18 -0
- mantisdk/utils/metrics.py +1025 -0
- mantisdk/utils/otel.py +578 -0
- mantisdk/utils/otlp.py +536 -0
- mantisdk/utils/server_launcher.py +1045 -0
- mantisdk/utils/system_snapshot.py +81 -0
- mantisdk/verl/__init__.py +8 -0
- mantisdk/verl/__main__.py +6 -0
- mantisdk/verl/async_server.py +46 -0
- mantisdk/verl/config.yaml +27 -0
- mantisdk/verl/daemon.py +1154 -0
- mantisdk/verl/dataset.py +44 -0
- mantisdk/verl/entrypoint.py +248 -0
- mantisdk/verl/trainer.py +549 -0
- mantisdk-0.1.0.dist-info/METADATA +119 -0
- mantisdk-0.1.0.dist-info/RECORD +190 -0
- mantisdk-0.1.0.dist-info/WHEEL +4 -0
- mantisdk-0.1.0.dist-info/entry_points.txt +2 -0
- mantisdk-0.1.0.dist-info/licenses/LICENSE +19 -0
|
@@ -0,0 +1,58 @@
|
|
|
1
|
+
# Copyright (c) Microsoft. All rights reserved.
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import Any, Dict, Optional, Protocol, runtime_checkable
|
|
6
|
+
|
|
7
|
+
from mantisdk.types import Attempt, AttemptedRollout, NamedResources, ResourcesUpdate, Rollout, Span
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
@runtime_checkable
class StorageListener(Protocol):
    """Protocol for listening to storage events.

    Listeners can be attached to a LightningStore to observe state changes
    and perform side effects (logging, tracking, etc.) without modifying
    the core storage logic.

    Marked ``@runtime_checkable`` so stores can probe listeners with
    ``isinstance``; note that runtime checks only verify method presence,
    not signatures.
    """

    @property
    def capabilities(self) -> Dict[str, bool]:
        """Return the capabilities of the listener (e.g., {"otlp_traces": True})."""
        ...

    def otlp_traces_endpoint(self) -> Optional[str]:
        """Return OTLP endpoint if supported, else None."""
        ...

    def get_otlp_headers(self) -> Dict[str, str]:
        """Return OTLP headers if supported, else empty dict."""
        ...

    async def on_job_created(self, job_id: str, project_id: Optional[str] = None) -> None:
        """Called when the store/job is initialized."""
        ...

    async def on_rollout_created(self, rollout: Rollout) -> None:
        """Called when a rollout is created (start or enqueue)."""
        ...

    async def on_rollout_updated(self, rollout: Rollout) -> None:
        """Called when a rollout is updated (status change, etc.)."""
        ...

    async def on_attempt_created(self, attempt: Attempt) -> None:
        """Called when an attempt is created."""
        ...

    async def on_attempt_updated(self, attempt: Attempt, rollout_id: str) -> None:
        """Called when an attempt is updated."""
        ...

    async def on_span_created(self, span: Span) -> None:
        """Called when a span is added."""
        ...

    async def on_resource_registered(self, resource: ResourcesUpdate) -> None:
        """Called when a resource snapshot is registered/updated."""
        ...
|
mantisdk/store/memory.py
ADDED
|
@@ -0,0 +1,396 @@
|
|
|
1
|
+
# Copyright (c) Microsoft. All rights reserved.
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import asyncio
|
|
6
|
+
import logging
|
|
7
|
+
import sys
|
|
8
|
+
from collections.abc import Iterable
|
|
9
|
+
from collections.abc import Mapping as MappingABC
|
|
10
|
+
from typing import (
|
|
11
|
+
TYPE_CHECKING,
|
|
12
|
+
Any,
|
|
13
|
+
Callable,
|
|
14
|
+
Counter,
|
|
15
|
+
Dict,
|
|
16
|
+
List,
|
|
17
|
+
Literal,
|
|
18
|
+
Mapping,
|
|
19
|
+
Optional,
|
|
20
|
+
Sequence,
|
|
21
|
+
Set,
|
|
22
|
+
Tuple,
|
|
23
|
+
TypeVar,
|
|
24
|
+
Union,
|
|
25
|
+
cast,
|
|
26
|
+
)
|
|
27
|
+
|
|
28
|
+
if TYPE_CHECKING:
|
|
29
|
+
from .listener import StorageListener
|
|
30
|
+
|
|
31
|
+
import aiologic
|
|
32
|
+
from pydantic import BaseModel
|
|
33
|
+
|
|
34
|
+
from mantisdk.types import AttemptedRollout, NamedResources, PaginatedResult, ResourcesUpdate, Rollout, Span
|
|
35
|
+
from mantisdk.utils.metrics import MetricsBackend
|
|
36
|
+
|
|
37
|
+
from .base import UNSET, LightningStoreCapabilities, LightningStoreStatistics, Unset, is_finished, is_running
|
|
38
|
+
from .collection import InMemoryLightningCollections
|
|
39
|
+
from .collection_based import CollectionBasedLightningStore, tracked
|
|
40
|
+
|
|
41
|
+
T_callable = TypeVar("T_callable", bound=Callable[..., Any])
|
|
42
|
+
|
|
43
|
+
logger = logging.getLogger(__name__)
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def estimate_model_size(obj: Any) -> int:
    """Recursively approximate the in-memory footprint of ``obj`` in bytes.

    Pydantic ``BaseModel`` instances, mappings, and list/tuple/set containers
    contribute their shallow ``sys.getsizeof`` plus the estimated size of every
    child value; any other object contributes only its shallow size. The result
    is a rough estimate — ``sys.getsizeof`` does not follow references, which is
    exactly why containers are walked explicitly here.
    """
    shallow = sys.getsizeof(cast(object, obj))

    # Pick the children to recurse into, if any.
    children: Optional[Iterable[Any]] = None
    if isinstance(obj, BaseModel):
        children = cast(Iterable[Any], obj.__dict__.values())
    elif isinstance(obj, MappingABC):
        children = cast(Mapping[Any, Any], obj).values()
    elif isinstance(obj, (list, tuple, set)):
        children = cast(Iterable[Any], obj)

    if children is None:
        return shallow
    return shallow + sum(estimate_model_size(child) for child in children)
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def _detect_total_memory_bytes() -> int:
|
|
62
|
+
"""Best-effort detection of the total available system memory in bytes."""
|
|
63
|
+
|
|
64
|
+
try:
|
|
65
|
+
import psutil
|
|
66
|
+
|
|
67
|
+
return int(psutil.virtual_memory().total)
|
|
68
|
+
except ImportError:
|
|
69
|
+
# Fallback to 8GB if memory cannot be detected.
|
|
70
|
+
logger.error("psutil is not installed. Falling back to 8GB of memory in total.")
|
|
71
|
+
return 8 * 1024**3
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
class InMemoryLightningStore(CollectionBasedLightningStore[InMemoryLightningCollections]):
    """
    In-memory implementation of LightningStore using Python data structures.

    Thread-safe and async-compatible but data is not persistent.

    Args:
        thread_safe: Whether the store is thread-safe.
        eviction_memory_threshold: The threshold for evicting spans in bytes.
            By default, it's 70% of the total system memory available
            (see ``_detect_total_memory_bytes``).
        safe_memory_threshold: The threshold for safe memory usage in bytes.
            By default, it's 80% of the eviction threshold.
        span_size_estimator: A function to estimate the size of a span in bytes.
            By default, it's a simple size estimator that uses sys.getsizeof.
        tracker: The metrics tracker to use.
        scan_debounce_seconds: The debounce time for the scan for unhealthy rollouts.
            Set to 0 to disable debouncing.
        listeners: Optional storage listeners notified of store events by the base class.
    """

    def __init__(
        self,
        *,
        thread_safe: bool = False,
        eviction_memory_threshold: float | int | None = None,
        safe_memory_threshold: float | int | None = None,
        span_size_estimator: Callable[[Span], int] | None = None,
        tracker: MetricsBackend | None = None,
        scan_debounce_seconds: float = 10.0,
        listeners: Optional[Sequence["StorageListener"]] = None,
    ):
        # Collections pick a thread lock or an asyncio lock depending on thread_safe.
        super().__init__(
            collections=InMemoryLightningCollections(lock_type="thread" if thread_safe else "asyncio", tracker=tracker),
            tracker=tracker,
            scan_debounce_seconds=scan_debounce_seconds,
            listeners=listeners,
        )

        self._thread_safe = thread_safe
        # rollout_id -> rollout start_time; used to order eviction candidates (oldest first).
        self._start_time_by_rollout: Dict[str, float] = {}
        # rollout_id -> accumulated estimated span bytes (Counter so missing keys read as 0).
        self._span_bytes_by_rollout: Dict[str, int] = Counter()
        # Sum of all estimated span bytes currently held in memory.
        self._total_span_bytes: int = 0
        # Rollout ids whose spans were evicted; query_spans raises for these.
        self._evicted_rollout_span_sets: Set[str] = set()

        self._memory_capacity_bytes = _detect_total_memory_bytes()
        if self._memory_capacity_bytes <= 0:
            raise ValueError("Detected memory capacity must be positive")

        # Evicting starts above this many bytes (float args are ratios of capacity).
        self._eviction_threshold_bytes = self._resolve_memory_threshold(
            eviction_memory_threshold,
            default_ratio=0.7,
            capacity_bytes=self._memory_capacity_bytes,
            name="eviction_memory_threshold",
            minimum=1,
        )

        if safe_memory_threshold is None:
            safe_memory_threshold = max(int(self._eviction_threshold_bytes * 0.8), 0)

        # Eviction stops once usage drops back under this many bytes.
        self._safe_threshold_bytes = self._resolve_memory_threshold(
            safe_memory_threshold,
            default_ratio=self._eviction_threshold_bytes / self._memory_capacity_bytes,
            capacity_bytes=self._memory_capacity_bytes,
            name="safe_memory_threshold",
            minimum=0,
        )

        if not (0 <= self._safe_threshold_bytes < self._eviction_threshold_bytes):
            raise ValueError("safe_memory_threshold must be smaller than eviction_memory_threshold")
        self._custom_span_size_estimator = span_size_estimator

        # Completion tracking for wait_for_rollouts (cross-loop safe)
        self._completion_events: Dict[str, aiologic.Event] = {}

        # Running rollouts cache, including preparing and running rollouts
        self._running_rollout_ids: Set[str] = set()

        # Caches the latest resources ID.
        self._latest_resources_id: Union[str, None, Unset] = UNSET

    @property
    def capabilities(self) -> LightningStoreCapabilities:
        """Return the capabilities of the store.

        Merges base store capabilities with listener capabilities (e.g., InsightTracker).
        """
        base_caps = LightningStoreCapabilities(
            thread_safe=self._thread_safe,
            async_safe=True,
            zero_copy=False,
            otlp_traces=False,
        )

        # Merge capabilities from listeners (e.g., InsightTracker provides otlp_traces=True)
        for listener in self.listeners:
            if hasattr(listener, "capabilities"):
                base_caps.update(listener.capabilities)

        return base_caps

    async def statistics(self) -> LightningStoreStatistics:
        """Return the statistics of the store, extending the base statistics
        with the in-memory span accounting and memory thresholds."""
        return {
            **(await super().statistics()),
            "total_span_bytes": self._total_span_bytes,
            "eviction_threshold_bytes": self._eviction_threshold_bytes,
            "safe_threshold_bytes": self._safe_threshold_bytes,
            "memory_capacity_bytes": self._memory_capacity_bytes,
        }

    @tracked("wait_for_rollout")
    async def wait_for_rollout(self, rollout_id: str, timeout: Optional[float] = None) -> Optional[Rollout]:
        """Wait for a specific rollout to complete with a timeout.

        Returns the rollout if it finishes within ``timeout`` seconds, else None.
        ``timeout=None`` waits indefinitely; ``timeout <= 0`` is a non-blocking check.
        """
        # Fast path: the rollout may already be finished.
        async with self.collections.atomic(mode="r", snapshot=self._read_snapshot, labels=["rollouts"]) as collections:
            rollout = await collections.rollouts.get({"rollout_id": {"exact": rollout_id}})
            if rollout and is_finished(rollout):
                return rollout

        if timeout is not None and timeout <= 0:
            return None

        # If not completed and we have an event, wait for completion
        if rollout_id in self._completion_events:
            evt = self._completion_events[rollout_id]

            # Wait for the event with proper timeout handling
            # evt.wait() returns True if event was set, False if timeout occurred
            if timeout is None:
                # Wait indefinitely by polling with finite timeouts
                # This allows threads to exit cleanly on shutdown
                while True:
                    result = await asyncio.to_thread(evt.wait, 10.0)  # Poll every 10 seconds
                    if result:  # Event was set
                        break
                    # Loop and check again (continues indefinitely since timeout=None)
            else:
                # Wait with the specified timeout
                result = await asyncio.to_thread(evt.wait, timeout)

            # If event was set (not timeout), check if rollout is finished
            if result:
                async with self.collections.atomic(
                    mode="r", snapshot=self._read_snapshot, labels=["rollouts"]
                ) as collections:
                    rollout = await collections.rollouts.get({"rollout_id": {"exact": rollout_id}})
                    if rollout and is_finished(rollout):
                        return rollout

        return None

    @tracked("add_resources_inmemory")
    async def add_resources(self, resources: NamedResources) -> ResourcesUpdate:
        """Add a resource snapshot via the base class, then cache its id as the latest."""
        ret = await super().add_resources(resources)
        # rw lock on resources so readers of _latest_resources_id see a consistent value.
        async with self.collections.atomic(mode="rw", snapshot=self._read_snapshot, labels=["resources"]):
            self._latest_resources_id = ret.resources_id
        return ret

    @tracked("update_resources_inmemory")
    async def update_resources(self, resources_id: str, resources: NamedResources) -> ResourcesUpdate:
        """Update a resource snapshot via the base class, then cache its id as the latest."""
        ret = await super().update_resources(resources_id, resources)
        async with self.collections.atomic(mode="rw", snapshot=self._read_snapshot, labels=["resources"]):
            self._latest_resources_id = ret.resources_id
        return ret

    @tracked("_post_update_rollout_inmemory")
    async def _post_update_rollout(
        self, rollouts: Sequence[Tuple[Rollout, Sequence[str]]], skip_enqueue: bool = False
    ) -> None:
        """Update the running rollout ids set when the rollout updates.

        Also maintains the per-rollout completion events and the start-time
        map used as the eviction ordering key.
        """
        await super()._post_update_rollout(rollouts, skip_enqueue=skip_enqueue)
        async with self.collections.atomic(mode="rw", snapshot=self._read_snapshot, labels=["rollouts"]):
            for rollout, _ in rollouts:
                if is_running(rollout):
                    self._running_rollout_ids.add(rollout.rollout_id)
                else:
                    self._running_rollout_ids.discard(rollout.rollout_id)

                if is_finished(rollout):
                    self._completion_events.setdefault(rollout.rollout_id, aiologic.Event())
                    self._completion_events[rollout.rollout_id].set()
                else:
                    self._completion_events.setdefault(rollout.rollout_id, aiologic.Event())
                # Rollout status can never transition from finished to running (unlike attempt)
                # so we don't need to clear the completion event even in case of retrying.

                # Record the first-seen start time only; later updates must not move it.
                if rollout.rollout_id not in self._start_time_by_rollout:
                    self._start_time_by_rollout[rollout.rollout_id] = rollout.start_time

    @tracked("_unlocked_query_rollouts_by_rollout_ids")
    async def _unlocked_query_rollouts_by_rollout_ids(
        self, collections: InMemoryLightningCollections, rollout_ids: Sequence[str]
    ) -> List[Rollout]:
        """Always use exact. This is faster than within filter for in-memory store.

        Missing ids are silently dropped from the result.
        """
        if len(rollout_ids) == 0:
            return []

        rollouts = [await collections.rollouts.get({"rollout_id": {"exact": rollout_id}}) for rollout_id in rollout_ids]
        return [rollout for rollout in rollouts if rollout is not None]

    @tracked("_unlocked_get_running_rollouts")
    async def _unlocked_get_running_rollouts(self, collections: InMemoryLightningCollections) -> List[AttemptedRollout]:
        """Accelerated version of `_unlocked_get_running_rollouts` for in-memory store. Used for healthcheck."""
        # NOTE(review): the `collections` parameter is shadowed by the atomic
        # context below; the method acquires its own read lock. Confirm this is intended.
        async with self.collections.atomic(
            mode="r", snapshot=self._read_snapshot, labels=["rollouts", "attempts"]
        ) as collections:
            rollouts = await self._unlocked_query_rollouts_by_rollout_ids(collections, list(self._running_rollout_ids))
            running_rollouts: List[AttemptedRollout] = []
            for rollout in rollouts:
                # Latest attempt = highest sequence_id.
                latest_attempt = await collections.attempts.get(
                    filter={"rollout_id": {"exact": rollout.rollout_id}},
                    sort={"name": "sequence_id", "order": "desc"},
                )
                if not latest_attempt:
                    # The rollout is running but has no attempts, this should not happen
                    logger.error(f"Rollout {rollout.rollout_id} is running but has no attempts")
                    continue
                running_rollouts.append(AttemptedRollout(**rollout.model_dump(), attempt=latest_attempt))
            return running_rollouts

    @tracked("query_spans_inmemory")  # Since this method calls super, we need to track it separately
    async def query_spans(
        self,
        rollout_id: str,
        attempt_id: str | Literal["latest"] | None = None,
        **kwargs: Any,
    ) -> PaginatedResult[Span]:
        """Query spans for a rollout, raising RuntimeError if the rollout's spans
        were previously evicted to reclaim memory."""
        if rollout_id in self._evicted_rollout_span_sets:
            raise RuntimeError(f"Spans for rollout {rollout_id} have been evicted")
        return await super().query_spans(rollout_id, attempt_id, **kwargs)

    @tracked("_post_add_spans")
    async def _post_add_spans(self, spans: Sequence[Span], rollout_id: str, attempt_id: str) -> None:
        """In-memory store needs to maintain the span data in memory, and evict spans when memory is low."""

        await super()._post_add_spans(spans, rollout_id, attempt_id)
        async with self.collections.atomic(
            mode="rw", snapshot=self._read_snapshot, labels=["rollouts", "spans"]
        ) as collections:
            for span in spans:
                await self._account_span_size(span)
            # Accounting done; evict oldest rollouts' spans if over the threshold.
            await self._maybe_evict_spans(collections)

    @tracked("_get_latest_resources_inmemory")
    async def _get_latest_resources(self) -> Optional[ResourcesUpdate]:
        """Return the most recent resource snapshot, using the cached id when available.

        UNSET means nothing has been cached yet, so fall back to the base-class lookup;
        a cached None means there are no resources at all.
        """
        if isinstance(self._latest_resources_id, Unset):
            return await super()._get_latest_resources()
        if self._latest_resources_id is not None:
            async with self.collections.atomic(
                mode="r", snapshot=self._read_snapshot, labels=["resources"]
            ) as collections:
                return await collections.resources.get(filter={"resources_id": {"exact": self._latest_resources_id}})
        return None

    @staticmethod
    def _resolve_memory_threshold(
        value: float | int | None,
        *,
        default_ratio: float,
        capacity_bytes: int,
        name: str,
        minimum: int,
    ) -> int:
        """Normalize a threshold setting into absolute bytes.

        None -> ``capacity_bytes * default_ratio``; a float is treated as a
        ratio of capacity (validated against [0, 1]); an int is taken as bytes.
        Raises ValueError if the resolved value is invalid or below ``minimum``.
        """
        if value is None:
            resolved = int(capacity_bytes * default_ratio)
        elif isinstance(value, float):
            # minimum == 0 permits a zero ratio; otherwise the ratio must be strictly positive.
            if minimum == 0:
                if not (0 <= value <= 1):
                    raise ValueError(f"{name} ratio must be between 0 and 1 inclusive")
            else:
                if not (0 < value <= 1):
                    raise ValueError(f"{name} ratio must be greater than 0 and at most 1")
            resolved = int(capacity_bytes * value)
        else:
            value_int = value
            if value_int < 0:
                raise ValueError(f"{name} must be non-negative")
            resolved = value_int

        if resolved < minimum:
            raise ValueError(f"{name} must be at least {minimum} bytes")

        return resolved

    @tracked("_account_span_size")
    async def _account_span_size(self, span: Span) -> int:
        """Estimate the span's size and add it to the per-rollout and total counters.

        Returns the size that was accounted, in bytes.
        """
        if self._custom_span_size_estimator is not None:
            # Clamp custom estimates to be non-negative so accounting can't go backwards.
            size = max(int(self._custom_span_size_estimator(span)), 0)
        else:
            size = estimate_model_size(span)

        self._span_bytes_by_rollout[span.rollout_id] += size
        self._total_span_bytes += size
        return size

    @tracked("_maybe_evict_spans")
    async def _maybe_evict_spans(self, collections: InMemoryLightningCollections) -> None:
        """Evict spans of the oldest rollouts until usage drops to the safe threshold.

        No-op while total span bytes are at or below the eviction threshold.
        """
        if self._total_span_bytes <= self._eviction_threshold_bytes:
            return

        logger.info(
            f"Total span bytes: {self._total_span_bytes}, eviction threshold: {self._eviction_threshold_bytes}, "
            f"safe threshold: {self._safe_threshold_bytes}. Evicting spans..."
        )
        # Sort by start time so the oldest rollouts lose their spans first.
        candidates: List[tuple[float, str]] = [
            (start_time, rollout_id) for rollout_id, start_time in self._start_time_by_rollout.items()
        ]
        candidates.sort()

        logger.info(f"Evicting spans for {len(candidates)} rollouts to free up memory...")
        memory_consumed_before = self._total_span_bytes
        for _, rollout_id in candidates:
            if self._total_span_bytes <= self._safe_threshold_bytes:
                break
            logger.debug(f"Evicting spans for rollout {rollout_id} to free up memory...")
            await self._evict_spans_for_rollout(collections, rollout_id)
        logger.info(f"Freed up {memory_consumed_before - self._total_span_bytes} bytes of memory")

    @tracked("_evict_spans_for_rollout")
    async def _evict_spans_for_rollout(self, collections: InMemoryLightningCollections, rollout_id: str) -> None:
        """Drop all spans of one rollout from the collections and undo its byte accounting."""
        await collections.evict_spans_for_rollout(rollout_id)
        removed_bytes = self._span_bytes_by_rollout.pop(rollout_id, 0)
        if removed_bytes > 0:
            # There is something removed for real
            self._total_span_bytes = max(self._total_span_bytes - removed_bytes, 0)
            self._evicted_rollout_span_sets.add(rollout_id)
mantisdk/store/mongo.py
ADDED
|
@@ -0,0 +1,165 @@
|
|
|
1
|
+
# Copyright (c) Microsoft. All rights reserved.
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import asyncio
|
|
6
|
+
import hashlib
|
|
7
|
+
import logging
|
|
8
|
+
import time
|
|
9
|
+
import uuid
|
|
10
|
+
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, TypeVar, Union
|
|
11
|
+
|
|
12
|
+
from mantisdk.types import Attempt, AttemptedRollout, Rollout
|
|
13
|
+
from mantisdk.utils.metrics import MetricsBackend
|
|
14
|
+
|
|
15
|
+
from .base import LightningStoreCapabilities, is_finished
|
|
16
|
+
from .collection.mongo import MongoClientPool, MongoLightningCollections
|
|
17
|
+
from .collection_based import CollectionBasedLightningStore, healthcheck_before, tracked
|
|
18
|
+
|
|
19
|
+
T_callable = TypeVar("T_callable", bound=Callable[..., Any])
|
|
20
|
+
|
|
21
|
+
logger = logging.getLogger(__name__)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def _generate_partition_id() -> str:
|
|
25
|
+
return "pt-" + hashlib.sha1(uuid.uuid4().bytes).hexdigest()[:12]
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class MongoLightningStore(CollectionBasedLightningStore[MongoLightningCollections]):
    """
    MongoDB implementation of LightningStore using MongoDB collections.

    Data is persistent and can be shared between multiple processes.

    Args:
        mongo_uri: MongoDB connection string (defaults to local replica set).
        mongo_client_kwargs: Extra keyword arguments forwarded to `AsyncMongoClient`.
        database_name: The MongoDB database name. Defaults to ``mantisdk``.
        partition_id: The partition id. Useful when sharing the database among multiple Mantisdk trainers.
        tracker: The metrics tracker to use.
        scan_debounce_seconds: The debounce time for the scan for unhealthy rollouts.
            Set to 0 to disable debouncing.
    """

    def __init__(
        self,
        *,
        mongo_uri: str = "mongodb://localhost:27017/?replicaSet=rs0",
        mongo_client_kwargs: Mapping[str, Any] | None = None,
        database_name: str | None = None,
        partition_id: str | None = None,
        tracker: MetricsBackend | None = None,
        scan_debounce_seconds: float = 10.0,
    ) -> None:
        self._mongo_uri = mongo_uri
        # Copy so a caller-supplied mapping cannot be mutated behind our back.
        self._mongo_client_kwargs = dict(mongo_client_kwargs or {})

        if database_name is None:
            database_name = "mantisdk"
            logger.info("No database name provided, using default 'mantisdk'")

        if partition_id is None:
            # Each trainer gets its own partition unless one is shared explicitly.
            partition_id = _generate_partition_id()
            logger.info("No partition id provided, generated a new one: %s", partition_id)

        self._client_pool = MongoClientPool[Mapping[str, Any]](
            mongo_uri=self._mongo_uri,
            mongo_client_kwargs=self._mongo_client_kwargs,
        )

        super().__init__(
            collections=MongoLightningCollections(
                self._client_pool,
                database_name,
                partition_id,
                tracker=tracker,
            ),
            tracker=tracker,
            scan_debounce_seconds=scan_debounce_seconds,
        )

    @property
    def capabilities(self) -> LightningStoreCapabilities:
        """Return the capabilities of the store."""
        return LightningStoreCapabilities(
            thread_safe=True,
            async_safe=True,
            zero_copy=True,
            otlp_traces=False,
        )

    async def close(self) -> None:
        """Close the store by closing the client pool."""
        await self._client_pool.close()

    @tracked("wait_for_rollouts")
    @healthcheck_before
    async def wait_for_rollouts(self, *, rollout_ids: List[str], timeout: Optional[float] = None) -> List[Rollout]:
        """Wait for specified rollouts to complete with a timeout.

        Concurrently wait for all rollouts to complete with a timeout.

        Polls the rollouts collection (roughly every 10 seconds) until every
        requested rollout is finished or the deadline passes. On timeout the
        rollouts finished so far are returned — the result may be shorter than
        ``rollout_ids``.
        """
        start_time = time.time()
        current_time = start_time
        # deadline is None means wait forever.
        deadline = start_time + timeout if timeout is not None else None

        finished_rollouts: Dict[str, Rollout] = {}
        unfinished_rollout_ids = set(rollout_ids)

        while deadline is None or current_time <= deadline:
            async with self.collections.atomic(
                mode="r", snapshot=self._read_snapshot, labels=["rollouts"]
            ) as collections:
                # Query the rollouts that are not finished in a single query
                rollouts = await collections.rollouts.query(
                    filter={"rollout_id": {"within": list(unfinished_rollout_ids)}}
                )
                for rollout in rollouts.items:
                    if is_finished(rollout):
                        finished_rollouts[rollout.rollout_id] = rollout
                        unfinished_rollout_ids.remove(rollout.rollout_id)

            if not unfinished_rollout_ids:
                break

            # Poll every 10 seconds by default
            # Minus 0.1 to make sure the time is still sufficient for another call
            rest_time = max(0.01, min(deadline - time.time() - 0.1, 10.0)) if deadline is not None else 10.0
            await asyncio.sleep(rest_time)
            current_time = time.time()

            # Logging will help debugging when there are stuck rollouts.
            logger.debug(
                "Waiting for rollouts. Number of finished rollouts: %d; number of unfinished rollouts: %d",
                len(finished_rollouts),
                len(unfinished_rollout_ids),
            )
            # Only enumerate ids when the set is small enough to be readable.
            if len(unfinished_rollout_ids) < 30:
                logger.debug("Unfinished rollouts: %s", unfinished_rollout_ids)

        # Reorder the rollouts to match the input order
        return [finished_rollouts[rollout_id] for rollout_id in rollout_ids if rollout_id in finished_rollouts]

    @tracked("_unlocked_many_rollouts_to_attempted_rollouts")
    async def _unlocked_many_rollouts_to_attempted_rollouts(
        self, collections: MongoLightningCollections, rollouts: Sequence[Rollout]
    ) -> List[Union[Rollout, AttemptedRollout]]:
        """Query the latest attempts for the rollouts, and attach them to the rollout objects.

        Rollouts with no recorded attempt are passed through unchanged.
        """
        # NOTE(review): the `as collections` rebinds the parameter inside this
        # scope — intentional-looking but easy to misread; verify on refactor.
        async with collections.atomic(mode="r", snapshot=self._read_snapshot, labels=["attempts"]) as collections:
            attempts = await collections.attempts.query(
                filter={"rollout_id": {"within": [rollout.rollout_id for rollout in rollouts]}},
                sort={"name": "sequence_id", "order": "desc"},
            )
            # NOTE(review): `wait_for_rollouts` iterates `query(...).items` while
            # this iterates the result directly — presumably both are valid for
            # this query type, but confirm against the collections API.
            latest_attempts: Dict[str, Attempt] = {}
            # Results are sorted by sequence_id descending, so the first attempt
            # seen per rollout is the latest one.
            for attempt in attempts:
                if attempt.rollout_id not in latest_attempts:
                    latest_attempts[attempt.rollout_id] = attempt
                # Otherwise we ignore the attempt because there's already a newer attempt

        return [
            (
                AttemptedRollout(**rollout.model_dump(), attempt=latest_attempts[rollout.rollout_id])
                if rollout.rollout_id in latest_attempts
                else rollout
            )
            for rollout in rollouts
        ]
|
mantisdk/store/sqlite.py
ADDED