mantisdk 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mantisdk might be problematic. Click here for more details.

Files changed (190)
  1. mantisdk/__init__.py +22 -0
  2. mantisdk/adapter/__init__.py +15 -0
  3. mantisdk/adapter/base.py +94 -0
  4. mantisdk/adapter/messages.py +270 -0
  5. mantisdk/adapter/triplet.py +1028 -0
  6. mantisdk/algorithm/__init__.py +39 -0
  7. mantisdk/algorithm/apo/__init__.py +5 -0
  8. mantisdk/algorithm/apo/apo.py +889 -0
  9. mantisdk/algorithm/apo/prompts/apply_edit_variant01.poml +22 -0
  10. mantisdk/algorithm/apo/prompts/apply_edit_variant02.poml +18 -0
  11. mantisdk/algorithm/apo/prompts/text_gradient_variant01.poml +18 -0
  12. mantisdk/algorithm/apo/prompts/text_gradient_variant02.poml +16 -0
  13. mantisdk/algorithm/apo/prompts/text_gradient_variant03.poml +107 -0
  14. mantisdk/algorithm/base.py +162 -0
  15. mantisdk/algorithm/decorator.py +264 -0
  16. mantisdk/algorithm/fast.py +250 -0
  17. mantisdk/algorithm/gepa/__init__.py +59 -0
  18. mantisdk/algorithm/gepa/adapter.py +459 -0
  19. mantisdk/algorithm/gepa/gepa.py +364 -0
  20. mantisdk/algorithm/gepa/lib/__init__.py +18 -0
  21. mantisdk/algorithm/gepa/lib/adapters/README.md +12 -0
  22. mantisdk/algorithm/gepa/lib/adapters/__init__.py +0 -0
  23. mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/README.md +341 -0
  24. mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/__init__.py +1 -0
  25. mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/anymaths_adapter.py +174 -0
  26. mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/requirements.txt +1 -0
  27. mantisdk/algorithm/gepa/lib/adapters/default_adapter/README.md +0 -0
  28. mantisdk/algorithm/gepa/lib/adapters/default_adapter/__init__.py +0 -0
  29. mantisdk/algorithm/gepa/lib/adapters/default_adapter/default_adapter.py +209 -0
  30. mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/README.md +7 -0
  31. mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/__init__.py +0 -0
  32. mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/dspy_adapter.py +307 -0
  33. mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/README.md +99 -0
  34. mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/dspy_program_proposal_signature.py +137 -0
  35. mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/full_program_adapter.py +266 -0
  36. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/GEPA_RAG.md +621 -0
  37. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/__init__.py +56 -0
  38. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/evaluation_metrics.py +226 -0
  39. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/generic_rag_adapter.py +496 -0
  40. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/rag_pipeline.py +238 -0
  41. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_store_interface.py +212 -0
  42. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/__init__.py +2 -0
  43. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/chroma_store.py +196 -0
  44. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/lancedb_store.py +422 -0
  45. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/milvus_store.py +409 -0
  46. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/qdrant_store.py +368 -0
  47. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/weaviate_store.py +418 -0
  48. mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/README.md +552 -0
  49. mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/__init__.py +37 -0
  50. mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/mcp_adapter.py +705 -0
  51. mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/mcp_client.py +364 -0
  52. mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/README.md +9 -0
  53. mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/__init__.py +0 -0
  54. mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/terminal_bench_adapter.py +217 -0
  55. mantisdk/algorithm/gepa/lib/api.py +375 -0
  56. mantisdk/algorithm/gepa/lib/core/__init__.py +0 -0
  57. mantisdk/algorithm/gepa/lib/core/adapter.py +180 -0
  58. mantisdk/algorithm/gepa/lib/core/data_loader.py +74 -0
  59. mantisdk/algorithm/gepa/lib/core/engine.py +356 -0
  60. mantisdk/algorithm/gepa/lib/core/result.py +233 -0
  61. mantisdk/algorithm/gepa/lib/core/state.py +636 -0
  62. mantisdk/algorithm/gepa/lib/examples/__init__.py +0 -0
  63. mantisdk/algorithm/gepa/lib/examples/aime.py +24 -0
  64. mantisdk/algorithm/gepa/lib/examples/anymaths-bench/eval_default.py +111 -0
  65. mantisdk/algorithm/gepa/lib/examples/anymaths-bench/prompt-templates/instruction_prompt.txt +9 -0
  66. mantisdk/algorithm/gepa/lib/examples/anymaths-bench/prompt-templates/optimal_prompt.txt +24 -0
  67. mantisdk/algorithm/gepa/lib/examples/anymaths-bench/train_anymaths.py +177 -0
  68. mantisdk/algorithm/gepa/lib/examples/dspy_full_program_evolution/arc_agi.ipynb +25705 -0
  69. mantisdk/algorithm/gepa/lib/examples/dspy_full_program_evolution/example.ipynb +348 -0
  70. mantisdk/algorithm/gepa/lib/examples/mcp_adapter/__init__.py +4 -0
  71. mantisdk/algorithm/gepa/lib/examples/mcp_adapter/mcp_optimization_example.py +455 -0
  72. mantisdk/algorithm/gepa/lib/examples/rag_adapter/RAG_GUIDE.md +613 -0
  73. mantisdk/algorithm/gepa/lib/examples/rag_adapter/__init__.py +9 -0
  74. mantisdk/algorithm/gepa/lib/examples/rag_adapter/rag_optimization.py +824 -0
  75. mantisdk/algorithm/gepa/lib/examples/rag_adapter/requirements-rag.txt +29 -0
  76. mantisdk/algorithm/gepa/lib/examples/terminal-bench/prompt-templates/instruction_prompt.txt +16 -0
  77. mantisdk/algorithm/gepa/lib/examples/terminal-bench/prompt-templates/terminus.txt +9 -0
  78. mantisdk/algorithm/gepa/lib/examples/terminal-bench/train_terminus.py +161 -0
  79. mantisdk/algorithm/gepa/lib/gepa_utils.py +117 -0
  80. mantisdk/algorithm/gepa/lib/logging/__init__.py +0 -0
  81. mantisdk/algorithm/gepa/lib/logging/experiment_tracker.py +187 -0
  82. mantisdk/algorithm/gepa/lib/logging/logger.py +75 -0
  83. mantisdk/algorithm/gepa/lib/logging/utils.py +103 -0
  84. mantisdk/algorithm/gepa/lib/proposer/__init__.py +0 -0
  85. mantisdk/algorithm/gepa/lib/proposer/base.py +31 -0
  86. mantisdk/algorithm/gepa/lib/proposer/merge.py +357 -0
  87. mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/__init__.py +0 -0
  88. mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/base.py +49 -0
  89. mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/reflective_mutation.py +176 -0
  90. mantisdk/algorithm/gepa/lib/py.typed +0 -0
  91. mantisdk/algorithm/gepa/lib/strategies/__init__.py +0 -0
  92. mantisdk/algorithm/gepa/lib/strategies/batch_sampler.py +77 -0
  93. mantisdk/algorithm/gepa/lib/strategies/candidate_selector.py +50 -0
  94. mantisdk/algorithm/gepa/lib/strategies/component_selector.py +36 -0
  95. mantisdk/algorithm/gepa/lib/strategies/eval_policy.py +64 -0
  96. mantisdk/algorithm/gepa/lib/strategies/instruction_proposal.py +127 -0
  97. mantisdk/algorithm/gepa/lib/utils/__init__.py +10 -0
  98. mantisdk/algorithm/gepa/lib/utils/stop_condition.py +196 -0
  99. mantisdk/algorithm/gepa/tracing.py +105 -0
  100. mantisdk/algorithm/utils.py +177 -0
  101. mantisdk/algorithm/verl/__init__.py +5 -0
  102. mantisdk/algorithm/verl/interface.py +202 -0
  103. mantisdk/cli/__init__.py +56 -0
  104. mantisdk/cli/prometheus.py +115 -0
  105. mantisdk/cli/store.py +131 -0
  106. mantisdk/cli/vllm.py +29 -0
  107. mantisdk/client.py +408 -0
  108. mantisdk/config.py +348 -0
  109. mantisdk/emitter/__init__.py +43 -0
  110. mantisdk/emitter/annotation.py +370 -0
  111. mantisdk/emitter/exception.py +54 -0
  112. mantisdk/emitter/message.py +61 -0
  113. mantisdk/emitter/object.py +117 -0
  114. mantisdk/emitter/reward.py +320 -0
  115. mantisdk/env_var.py +156 -0
  116. mantisdk/execution/__init__.py +15 -0
  117. mantisdk/execution/base.py +64 -0
  118. mantisdk/execution/client_server.py +443 -0
  119. mantisdk/execution/events.py +69 -0
  120. mantisdk/execution/inter_process.py +16 -0
  121. mantisdk/execution/shared_memory.py +282 -0
  122. mantisdk/instrumentation/__init__.py +119 -0
  123. mantisdk/instrumentation/agentops.py +314 -0
  124. mantisdk/instrumentation/agentops_langchain.py +45 -0
  125. mantisdk/instrumentation/litellm.py +83 -0
  126. mantisdk/instrumentation/vllm.py +81 -0
  127. mantisdk/instrumentation/weave.py +500 -0
  128. mantisdk/litagent/__init__.py +11 -0
  129. mantisdk/litagent/decorator.py +536 -0
  130. mantisdk/litagent/litagent.py +252 -0
  131. mantisdk/llm_proxy.py +1890 -0
  132. mantisdk/logging.py +370 -0
  133. mantisdk/reward.py +7 -0
  134. mantisdk/runner/__init__.py +11 -0
  135. mantisdk/runner/agent.py +845 -0
  136. mantisdk/runner/base.py +182 -0
  137. mantisdk/runner/legacy.py +309 -0
  138. mantisdk/semconv.py +170 -0
  139. mantisdk/server.py +401 -0
  140. mantisdk/store/__init__.py +23 -0
  141. mantisdk/store/base.py +897 -0
  142. mantisdk/store/client_server.py +2092 -0
  143. mantisdk/store/collection/__init__.py +30 -0
  144. mantisdk/store/collection/base.py +587 -0
  145. mantisdk/store/collection/memory.py +970 -0
  146. mantisdk/store/collection/mongo.py +1412 -0
  147. mantisdk/store/collection_based.py +1823 -0
  148. mantisdk/store/insight.py +648 -0
  149. mantisdk/store/listener.py +58 -0
  150. mantisdk/store/memory.py +396 -0
  151. mantisdk/store/mongo.py +165 -0
  152. mantisdk/store/sqlite.py +3 -0
  153. mantisdk/store/threading.py +357 -0
  154. mantisdk/store/utils.py +142 -0
  155. mantisdk/tracer/__init__.py +16 -0
  156. mantisdk/tracer/agentops.py +242 -0
  157. mantisdk/tracer/base.py +287 -0
  158. mantisdk/tracer/dummy.py +106 -0
  159. mantisdk/tracer/otel.py +555 -0
  160. mantisdk/tracer/weave.py +677 -0
  161. mantisdk/trainer/__init__.py +6 -0
  162. mantisdk/trainer/init_utils.py +263 -0
  163. mantisdk/trainer/legacy.py +367 -0
  164. mantisdk/trainer/registry.py +12 -0
  165. mantisdk/trainer/trainer.py +618 -0
  166. mantisdk/types/__init__.py +6 -0
  167. mantisdk/types/core.py +553 -0
  168. mantisdk/types/resources.py +204 -0
  169. mantisdk/types/tracer.py +515 -0
  170. mantisdk/types/tracing.py +218 -0
  171. mantisdk/utils/__init__.py +1 -0
  172. mantisdk/utils/id.py +18 -0
  173. mantisdk/utils/metrics.py +1025 -0
  174. mantisdk/utils/otel.py +578 -0
  175. mantisdk/utils/otlp.py +536 -0
  176. mantisdk/utils/server_launcher.py +1045 -0
  177. mantisdk/utils/system_snapshot.py +81 -0
  178. mantisdk/verl/__init__.py +8 -0
  179. mantisdk/verl/__main__.py +6 -0
  180. mantisdk/verl/async_server.py +46 -0
  181. mantisdk/verl/config.yaml +27 -0
  182. mantisdk/verl/daemon.py +1154 -0
  183. mantisdk/verl/dataset.py +44 -0
  184. mantisdk/verl/entrypoint.py +248 -0
  185. mantisdk/verl/trainer.py +549 -0
  186. mantisdk-0.1.0.dist-info/METADATA +119 -0
  187. mantisdk-0.1.0.dist-info/RECORD +190 -0
  188. mantisdk-0.1.0.dist-info/WHEEL +4 -0
  189. mantisdk-0.1.0.dist-info/entry_points.txt +2 -0
  190. mantisdk-0.1.0.dist-info/licenses/LICENSE +19 -0
@@ -0,0 +1,58 @@
1
+ # Copyright (c) Microsoft. All rights reserved.
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Any, Dict, Optional, Protocol, runtime_checkable
6
+
7
+ from mantisdk.types import Attempt, AttemptedRollout, NamedResources, ResourcesUpdate, Rollout, Span
8
+
9
+
10
@runtime_checkable
class StorageListener(Protocol):
    """Structural interface for observers of store events.

    A listener is attached to a LightningStore and gets notified about state
    changes, which lets it perform side effects (logging, tracking, ...) while
    the store's core storage logic stays unaware of it. Being a
    ``runtime_checkable`` Protocol, ``isinstance`` only verifies that the
    members below exist, not their signatures.
    """

    # ----- capability / export configuration ---------------------------------

    @property
    def capabilities(self) -> Dict[str, bool]:
        """Capability flags contributed by this listener, e.g. ``{"otlp_traces": True}``."""
        ...

    def otlp_traces_endpoint(self) -> Optional[str]:
        """OTLP traces endpoint offered by this listener, or ``None`` when unsupported."""
        ...

    def get_otlp_headers(self) -> Dict[str, str]:
        """Headers to attach to OTLP exports; an empty dict when unsupported."""
        ...

    # ----- lifecycle hooks ----------------------------------------------------

    async def on_job_created(self, job_id: str, project_id: Optional[str] = None) -> None:
        """Hook fired once the store/job has been initialized."""
        ...

    async def on_rollout_created(self, rollout: Rollout) -> None:
        """Hook fired when a rollout comes into existence (started or enqueued)."""
        ...

    async def on_rollout_updated(self, rollout: Rollout) -> None:
        """Hook fired when a rollout changes (e.g. its status)."""
        ...

    async def on_attempt_created(self, attempt: Attempt) -> None:
        """Hook fired when a new attempt is created."""
        ...

    async def on_attempt_updated(self, attempt: Attempt, rollout_id: str) -> None:
        """Hook fired when an attempt of the rollout ``rollout_id`` changes."""
        ...

    async def on_span_created(self, span: Span) -> None:
        """Hook fired when a span is added to the store."""
        ...

    async def on_resource_registered(self, resource: ResourcesUpdate) -> None:
        """Hook fired when a resource snapshot is registered or updated."""
        ...
@@ -0,0 +1,396 @@
1
+ # Copyright (c) Microsoft. All rights reserved.
2
+
3
+ from __future__ import annotations
4
+
5
+ import asyncio
6
+ import logging
7
+ import sys
8
+ from collections.abc import Iterable
9
+ from collections.abc import Mapping as MappingABC
10
+ from typing import (
11
+ TYPE_CHECKING,
12
+ Any,
13
+ Callable,
14
+ Counter,
15
+ Dict,
16
+ List,
17
+ Literal,
18
+ Mapping,
19
+ Optional,
20
+ Sequence,
21
+ Set,
22
+ Tuple,
23
+ TypeVar,
24
+ Union,
25
+ cast,
26
+ )
27
+
28
+ if TYPE_CHECKING:
29
+ from .listener import StorageListener
30
+
31
+ import aiologic
32
+ from pydantic import BaseModel
33
+
34
+ from mantisdk.types import AttemptedRollout, NamedResources, PaginatedResult, ResourcesUpdate, Rollout, Span
35
+ from mantisdk.utils.metrics import MetricsBackend
36
+
37
+ from .base import UNSET, LightningStoreCapabilities, LightningStoreStatistics, Unset, is_finished, is_running
38
+ from .collection import InMemoryLightningCollections
39
+ from .collection_based import CollectionBasedLightningStore, tracked
40
+
41
# Module-scoped logger, named after this module.
logger = logging.getLogger(__name__)

# TypeVar bounded to arbitrary callables (presumably for typing decorators; not used in the visible code).
T_callable = TypeVar("T_callable", bound=Callable[..., Any])
44
+
45
+
46
def estimate_model_size(obj: Any) -> int:
    """Return a rough recursive byte-size estimate for ``obj``.

    The shallow ``sys.getsizeof`` of ``obj`` itself is always counted. On top of
    that, the function recurses into:

    * Pydantic models: the values stored in the instance ``__dict__``;
    * mappings: the values only (keys are not counted beyond the mapping's own
      shallow size);
    * lists, tuples and sets: every element.

    Any other object contributes just its shallow size. Shared objects are
    counted once per reference and there is no cycle detection.
    """
    own_size = sys.getsizeof(cast(object, obj))

    children: Iterable[Any]
    if isinstance(obj, BaseModel):
        children = cast(Iterable[Any], obj.__dict__.values())
    elif isinstance(obj, MappingABC):
        children = cast(Mapping[Any, Any], obj).values()
    elif isinstance(obj, (list, tuple, set)):
        children = cast(Iterable[Any], obj)
    else:
        # Leaf object: nothing to descend into.
        return own_size

    return own_size + sum(estimate_model_size(child) for child in children)
59
+
60
+
61
def _detect_total_memory_bytes() -> int:
    """Best-effort detection of the machine's total physical memory, in bytes.

    Detection order:

    1. ``psutil.virtual_memory().total`` when psutil is installed.
    2. ``os.sysconf("SC_PAGE_SIZE") * os.sysconf("SC_PHYS_PAGES")`` on platforms
       that expose these values (Linux and most other POSIX systems).
    3. A fixed 8 GiB fallback, reported through ``logger.error``.

    Returns:
        The total memory in bytes. Note this is the total, not the currently
        free/available amount.
    """
    try:
        import psutil

        return int(psutil.virtual_memory().total)
    except ImportError:
        pass

    # psutil is an optional dependency; POSIX exposes page size and physical page
    # count through sysconf, which is accurate enough for threshold computation.
    import os

    try:
        page_size = os.sysconf("SC_PAGE_SIZE")
        phys_pages = os.sysconf("SC_PHYS_PAGES")
    except (AttributeError, ValueError, OSError):
        # AttributeError: no os.sysconf (e.g. Windows);
        # ValueError / OSError: the configuration name is unknown or unsupported.
        page_size = phys_pages = -1
    if page_size > 0 and phys_pages > 0:
        return int(page_size * phys_pages)

    # Fallback to 8GB if memory cannot be detected.
    logger.error("psutil is not installed. Falling back to 8GB of memory in total.")
    return 8 * 1024**3
72
+
73
+
74
class InMemoryLightningStore(CollectionBasedLightningStore[InMemoryLightningCollections]):
    """
    In-memory implementation of LightningStore using Python data structures.
    Async-compatible (and thread-safe when ``thread_safe=True``); data is not persistent.

    The size of stored spans is accounted per rollout. Once the accounted total exceeds the
    eviction threshold, spans of the oldest rollouts (ordered by rollout start time) are evicted
    until the total is at or below the safe threshold. Querying spans of an evicted rollout
    raises ``RuntimeError``.

    Args:
        thread_safe: Whether the store is thread-safe. Selects a thread lock instead of an
            asyncio lock for the underlying collections.
        eviction_memory_threshold: The threshold for evicting spans. An ``int`` is an absolute
            number of bytes; a ``float`` is a ratio in ``(0, 1]`` of the detected total system
            memory. By default, it's 70% of the detected total system memory.
        safe_memory_threshold: The target accounted span size after an eviction pass. An ``int``
            is an absolute number of bytes; a ``float`` is a ratio in ``[0, 1]`` of the detected
            total system memory. By default, it's 80% of the eviction threshold. Must be smaller
            than the eviction threshold.
        span_size_estimator: A function to estimate the size of a span in bytes.
            By default, it's a simple recursive size estimator based on sys.getsizeof.
        tracker: The metrics tracker to use.
        scan_debounce_seconds: The debounce time for the scan for unhealthy rollouts.
            Set to 0 to disable debouncing.
        listeners: Optional storage listeners forwarded to the base store. Their ``capabilities``
            are merged into this store's capabilities.

    Raises:
        ValueError: If the detected memory capacity is not positive, a threshold is out of range,
            or the safe threshold is not smaller than the eviction threshold.
    """

    def __init__(
        self,
        *,
        thread_safe: bool = False,
        eviction_memory_threshold: float | int | None = None,
        safe_memory_threshold: float | int | None = None,
        span_size_estimator: Callable[[Span], int] | None = None,
        tracker: MetricsBackend | None = None,
        scan_debounce_seconds: float = 10.0,
        listeners: Optional[Sequence["StorageListener"]] = None,
    ):
        super().__init__(
            collections=InMemoryLightningCollections(lock_type="thread" if thread_safe else "asyncio", tracker=tracker),
            tracker=tracker,
            scan_debounce_seconds=scan_debounce_seconds,
            listeners=listeners,
        )

        self._thread_safe = thread_safe
        # Start time of every rollout seen by `_post_update_rollout`; orders eviction candidates (oldest first).
        self._start_time_by_rollout: Dict[str, float] = {}
        # Accounted span bytes per rollout. A Counter so `+=` on an unseen rollout starts from 0.
        self._span_bytes_by_rollout: Dict[str, int] = Counter()
        self._total_span_bytes: int = 0
        # Rollouts whose spans were evicted; `query_spans` raises for them instead of returning partial data.
        self._evicted_rollout_span_sets: Set[str] = set()

        self._memory_capacity_bytes = _detect_total_memory_bytes()
        if self._memory_capacity_bytes <= 0:
            raise ValueError("Detected memory capacity must be positive")

        self._eviction_threshold_bytes = self._resolve_memory_threshold(
            eviction_memory_threshold,
            default_ratio=0.7,
            capacity_bytes=self._memory_capacity_bytes,
            name="eviction_memory_threshold",
            minimum=1,
        )

        if safe_memory_threshold is None:
            # Default: 80% of the resolved eviction threshold, expressed in absolute bytes.
            safe_memory_threshold = max(int(self._eviction_threshold_bytes * 0.8), 0)

        self._safe_threshold_bytes = self._resolve_memory_threshold(
            safe_memory_threshold,
            default_ratio=self._eviction_threshold_bytes / self._memory_capacity_bytes,
            capacity_bytes=self._memory_capacity_bytes,
            name="safe_memory_threshold",
            minimum=0,
        )

        if not (0 <= self._safe_threshold_bytes < self._eviction_threshold_bytes):
            raise ValueError("safe_memory_threshold must be smaller than eviction_memory_threshold")
        self._custom_span_size_estimator = span_size_estimator

        # Completion tracking for wait_for_rollouts (cross-loop safe)
        self._completion_events: Dict[str, aiologic.Event] = {}

        # Running rollouts cache, including preparing and running rollouts
        self._running_rollout_ids: Set[str] = set()

        # Caches the latest resources ID.
        self._latest_resources_id: Union[str, None, Unset] = UNSET

    @property
    def capabilities(self) -> LightningStoreCapabilities:
        """Return the capabilities of the store.

        Merges base store capabilities with listener capabilities (e.g., InsightTracker).
        Listeners are applied in order with ``dict.update`` semantics, so a later listener
        overrides an earlier one (or the base value) for the same key.
        """
        base_caps = LightningStoreCapabilities(
            thread_safe=self._thread_safe,
            async_safe=True,
            zero_copy=False,
            otlp_traces=False,
        )

        # Merge capabilities from listeners (e.g., InsightTracker provides otlp_traces=True)
        for listener in self.listeners:
            if hasattr(listener, "capabilities"):
                base_caps.update(listener.capabilities)

        return base_caps

    async def statistics(self) -> LightningStoreStatistics:
        """Return the base store statistics extended with span-memory accounting figures."""
        return {
            **(await super().statistics()),
            "total_span_bytes": self._total_span_bytes,
            "eviction_threshold_bytes": self._eviction_threshold_bytes,
            "safe_threshold_bytes": self._safe_threshold_bytes,
            "memory_capacity_bytes": self._memory_capacity_bytes,
        }

    @tracked("wait_for_rollout")
    async def wait_for_rollout(self, rollout_id: str, timeout: Optional[float] = None) -> Optional[Rollout]:
        """Wait for a specific rollout to finish.

        Args:
            rollout_id: The rollout to wait for.
            timeout: Maximum seconds to wait. ``None`` waits indefinitely; ``<= 0`` only checks
                the current state.

        Returns:
            The finished rollout, or ``None`` if it did not finish in time or no completion
            event has been registered for it by `_post_update_rollout`.
        """
        async with self.collections.atomic(mode="r", snapshot=self._read_snapshot, labels=["rollouts"]) as collections:
            rollout = await collections.rollouts.get({"rollout_id": {"exact": rollout_id}})
            if rollout and is_finished(rollout):
                return rollout

        if timeout is not None and timeout <= 0:
            return None

        # If not completed and we have an event, wait for completion
        if rollout_id in self._completion_events:
            evt = self._completion_events[rollout_id]

            # evt.wait() returns True if the event was set, False if the timeout occurred.
            # It blocks, so it runs in a worker thread to keep the event loop responsive.
            if timeout is None:
                # Wait indefinitely by polling with finite timeouts.
                # This allows threads to exit cleanly on shutdown.
                while True:
                    result = await asyncio.to_thread(evt.wait, 10.0)  # Poll every 10 seconds
                    if result:  # Event was set
                        break
            else:
                result = await asyncio.to_thread(evt.wait, timeout)

            # If event was set (not timeout), check if rollout is finished
            if result:
                async with self.collections.atomic(
                    mode="r", snapshot=self._read_snapshot, labels=["rollouts"]
                ) as collections:
                    rollout = await collections.rollouts.get({"rollout_id": {"exact": rollout_id}})
                    if rollout and is_finished(rollout):
                        return rollout

        return None

    @tracked("add_resources_inmemory")
    async def add_resources(self, resources: NamedResources) -> ResourcesUpdate:
        """Add resources via the base store and cache the new resources id as the latest one."""
        ret = await super().add_resources(resources)
        async with self.collections.atomic(mode="rw", snapshot=self._read_snapshot, labels=["resources"]):
            self._latest_resources_id = ret.resources_id
        return ret

    @tracked("update_resources_inmemory")
    async def update_resources(self, resources_id: str, resources: NamedResources) -> ResourcesUpdate:
        """Update resources via the base store and cache the resulting resources id as the latest one."""
        ret = await super().update_resources(resources_id, resources)
        async with self.collections.atomic(mode="rw", snapshot=self._read_snapshot, labels=["resources"]):
            self._latest_resources_id = ret.resources_id
        return ret

    @tracked("_post_update_rollout_inmemory")
    async def _post_update_rollout(
        self, rollouts: Sequence[Tuple[Rollout, Sequence[str]]], skip_enqueue: bool = False
    ) -> None:
        """Update the running-rollout cache, completion events and start times after rollout updates."""
        await super()._post_update_rollout(rollouts, skip_enqueue=skip_enqueue)
        async with self.collections.atomic(mode="rw", snapshot=self._read_snapshot, labels=["rollouts"]):
            for rollout, _ in rollouts:
                if is_running(rollout):
                    self._running_rollout_ids.add(rollout.rollout_id)
                else:
                    self._running_rollout_ids.discard(rollout.rollout_id)

                # Every known rollout gets a completion event so waiters can block on it.
                completion_event = self._completion_events.setdefault(rollout.rollout_id, aiologic.Event())
                if is_finished(rollout):
                    completion_event.set()
                # Rollout status can never transition from finished to running (unlike attempt)
                # so we don't need to clear the completion event even in case of retrying.

                if rollout.rollout_id not in self._start_time_by_rollout:
                    self._start_time_by_rollout[rollout.rollout_id] = rollout.start_time

    @tracked("_unlocked_query_rollouts_by_rollout_ids")
    async def _unlocked_query_rollouts_by_rollout_ids(
        self, collections: InMemoryLightningCollections, rollout_ids: Sequence[str]
    ) -> List[Rollout]:
        """Always use exact. This is faster than within filter for in-memory store.

        Unknown rollout ids are skipped; the result keeps the order of ``rollout_ids``.
        """
        if not rollout_ids:
            return []

        rollouts = [await collections.rollouts.get({"rollout_id": {"exact": rollout_id}}) for rollout_id in rollout_ids]
        return [rollout for rollout in rollouts if rollout is not None]

    @tracked("_unlocked_get_running_rollouts")
    async def _unlocked_get_running_rollouts(self, collections: InMemoryLightningCollections) -> List[AttemptedRollout]:
        """Accelerated version of `_unlocked_get_running_rollouts` for in-memory store. Used for healthcheck.

        Pairs every cached running rollout with its latest attempt (highest ``sequence_id``).
        """
        # NOTE(review): the `collections` argument is shadowed by a freshly acquired read snapshot even though
        # the `_unlocked_` prefix suggests the caller already holds the lock — confirm `atomic` is re-entrant here.
        async with self.collections.atomic(
            mode="r", snapshot=self._read_snapshot, labels=["rollouts", "attempts"]
        ) as collections:
            rollouts = await self._unlocked_query_rollouts_by_rollout_ids(collections, list(self._running_rollout_ids))
            running_rollouts: List[AttemptedRollout] = []
            for rollout in rollouts:
                latest_attempt = await collections.attempts.get(
                    filter={"rollout_id": {"exact": rollout.rollout_id}},
                    sort={"name": "sequence_id", "order": "desc"},
                )
                if not latest_attempt:
                    # The rollout is running but has no attempts, this should not happen
                    logger.error("Rollout %s is running but has no attempts", rollout.rollout_id)
                    continue
                running_rollouts.append(AttemptedRollout(**rollout.model_dump(), attempt=latest_attempt))
            return running_rollouts

    @tracked("query_spans_inmemory")  # Since this method calls super, we need to track it separately
    async def query_spans(
        self,
        rollout_id: str,
        attempt_id: str | Literal["latest"] | None = None,
        **kwargs: Any,
    ) -> PaginatedResult[Span]:
        """Query spans through the base store.

        Raises:
            RuntimeError: If the spans of ``rollout_id`` have been evicted to free memory.
        """
        if rollout_id in self._evicted_rollout_span_sets:
            raise RuntimeError(f"Spans for rollout {rollout_id} have been evicted")
        return await super().query_spans(rollout_id, attempt_id, **kwargs)

    @tracked("_post_add_spans")
    async def _post_add_spans(self, spans: Sequence[Span], rollout_id: str, attempt_id: str) -> None:
        """In-memory store needs to maintain the span data in memory, and evict spans when memory is low."""

        await super()._post_add_spans(spans, rollout_id, attempt_id)
        async with self.collections.atomic(
            mode="rw", snapshot=self._read_snapshot, labels=["rollouts", "spans"]
        ) as collections:
            for span in spans:
                await self._account_span_size(span)
            await self._maybe_evict_spans(collections)

    @tracked("_get_latest_resources_inmemory")
    async def _get_latest_resources(self) -> Optional[ResourcesUpdate]:
        """Return the latest resources using the cached id; defer to the base store until an id is cached."""
        if isinstance(self._latest_resources_id, Unset):
            return await super()._get_latest_resources()
        if self._latest_resources_id is not None:
            async with self.collections.atomic(
                mode="r", snapshot=self._read_snapshot, labels=["resources"]
            ) as collections:
                return await collections.resources.get(filter={"resources_id": {"exact": self._latest_resources_id}})
        return None

    @staticmethod
    def _resolve_memory_threshold(
        value: float | int | None,
        *,
        default_ratio: float,
        capacity_bytes: int,
        name: str,
        minimum: int,
    ) -> int:
        """Resolve a memory threshold argument to an absolute number of bytes.

        ``None`` uses ``default_ratio`` of ``capacity_bytes``; a ``float`` is a ratio of
        ``capacity_bytes`` (``[0, 1]`` when ``minimum == 0``, otherwise ``(0, 1]``); any other
        value is taken as absolute bytes and must be non-negative.

        Raises:
            ValueError: If the value is out of range or resolves below ``minimum``.
        """
        if value is None:
            resolved = int(capacity_bytes * default_ratio)
        elif isinstance(value, float):
            if minimum == 0:
                if not (0 <= value <= 1):
                    raise ValueError(f"{name} ratio must be between 0 and 1 inclusive")
            else:
                if not (0 < value <= 1):
                    raise ValueError(f"{name} ratio must be greater than 0 and at most 1")
            resolved = int(capacity_bytes * value)
        else:
            value_int = value
            if value_int < 0:
                raise ValueError(f"{name} must be non-negative")
            resolved = value_int

        if resolved < minimum:
            raise ValueError(f"{name} must be at least {minimum} bytes")

        return resolved

    @tracked("_account_span_size")
    async def _account_span_size(self, span: Span) -> int:
        """Add the estimated size of ``span`` to the per-rollout and total accounting; return that size."""
        if self._custom_span_size_estimator is not None:
            # Negative estimates are clamped so they can never reduce the accounted total.
            size = max(int(self._custom_span_size_estimator(span)), 0)
        else:
            size = estimate_model_size(span)

        self._span_bytes_by_rollout[span.rollout_id] += size
        self._total_span_bytes += size
        return size

    @tracked("_maybe_evict_spans")
    async def _maybe_evict_spans(self, collections: InMemoryLightningCollections) -> None:
        """Evict spans of the oldest rollouts when the accounted span size exceeds the eviction threshold.

        Rollouts are processed in ascending start time, and eviction stops as soon as the
        accounted total is at or below the safe threshold.
        """
        if self._total_span_bytes <= self._eviction_threshold_bytes:
            return

        logger.info(
            "Total span bytes: %s, eviction threshold: %s, safe threshold: %s. Evicting spans...",
            self._total_span_bytes,
            self._eviction_threshold_bytes,
            self._safe_threshold_bytes,
        )
        # Only rollouts that currently hold accounted span bytes are worth evicting. Already-evicted rollouts
        # (their counter entry was popped) and rollouts without spans would otherwise be revisited on every
        # pass, and evicting a rollout accounted at 0 bytes would drop its spans without marking it evicted.
        candidates: List[Tuple[float, str]] = sorted(
            (start_time, rollout_id)
            for rollout_id, start_time in self._start_time_by_rollout.items()
            if self._span_bytes_by_rollout.get(rollout_id, 0) > 0
        )

        logger.info("Evicting spans for %s rollouts to free up memory...", len(candidates))
        memory_consumed_before = self._total_span_bytes
        for _, rollout_id in candidates:
            if self._total_span_bytes <= self._safe_threshold_bytes:
                break
            logger.debug("Evicting spans for rollout %s to free up memory...", rollout_id)
            await self._evict_spans_for_rollout(collections, rollout_id)
        logger.info("Freed up %s bytes of memory", memory_consumed_before - self._total_span_bytes)

    @tracked("_evict_spans_for_rollout")
    async def _evict_spans_for_rollout(self, collections: InMemoryLightningCollections, rollout_id: str) -> None:
        """Drop the spans of ``rollout_id`` and release its accounted bytes."""
        await collections.evict_spans_for_rollout(rollout_id)
        removed_bytes = self._span_bytes_by_rollout.pop(rollout_id, 0)
        if removed_bytes > 0:
            # There is something removed for real
            self._total_span_bytes = max(self._total_span_bytes - removed_bytes, 0)
            self._evicted_rollout_span_sets.add(rollout_id)
@@ -0,0 +1,165 @@
1
+ # Copyright (c) Microsoft. All rights reserved.
2
+
3
+ from __future__ import annotations
4
+
5
+ import asyncio
6
+ import hashlib
7
+ import logging
8
+ import time
9
+ import uuid
10
+ from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, TypeVar, Union
11
+
12
+ from mantisdk.types import Attempt, AttemptedRollout, Rollout
13
+ from mantisdk.utils.metrics import MetricsBackend
14
+
15
+ from .base import LightningStoreCapabilities, is_finished
16
+ from .collection.mongo import MongoClientPool, MongoLightningCollections
17
+ from .collection_based import CollectionBasedLightningStore, healthcheck_before, tracked
18
+
19
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Type variable for callables (bound to ``Callable[..., Any]``); currently unused within this module.
T_callable = TypeVar("T_callable", bound=Callable[..., Any])
22
+
23
+
24
+ def _generate_partition_id() -> str:
25
+ return "pt-" + hashlib.sha1(uuid.uuid4().bytes).hexdigest()[:12]
26
+
27
+
28
class MongoLightningStore(CollectionBasedLightningStore[MongoLightningCollections]):
    """
    MongoDB implementation of LightningStore using MongoDB collections.
    Data is persistent and can be shared between multiple processes.

    Args:
        mongo_uri: MongoDB connection string (defaults to local replica set).
        mongo_client_kwargs: Extra keyword arguments forwarded to `AsyncMongoClient`.
        database_name: The MongoDB database name. Defaults to ``mantisdk``.
        partition_id: The partition id. Useful when sharing the database among multiple Mantisdk trainers.
            A random ``pt-``-prefixed id is generated when omitted.
        tracker: The metrics tracker to use.
        scan_debounce_seconds: The debounce time for the scan for unhealthy rollouts.
            Set to 0 to disable debouncing.
    """

    def __init__(
        self,
        *,
        mongo_uri: str = "mongodb://localhost:27017/?replicaSet=rs0",
        mongo_client_kwargs: Mapping[str, Any] | None = None,
        database_name: str | None = None,
        partition_id: str | None = None,
        tracker: MetricsBackend | None = None,
        scan_debounce_seconds: float = 10.0,
    ) -> None:
        self._mongo_uri = mongo_uri
        # Shallow copy so later changes to the caller's mapping do not leak into the pool.
        self._mongo_client_kwargs = dict(mongo_client_kwargs or {})

        if database_name is None:
            database_name = "mantisdk"
            logger.info("No database name provided, using default 'mantisdk'")

        if partition_id is None:
            partition_id = _generate_partition_id()
            logger.info("No partition id provided, generated a new one: %s", partition_id)

        self._client_pool = MongoClientPool[Mapping[str, Any]](
            mongo_uri=self._mongo_uri,
            mongo_client_kwargs=self._mongo_client_kwargs,
        )

        super().__init__(
            collections=MongoLightningCollections(
                self._client_pool,
                database_name,
                partition_id,
                tracker=tracker,
            ),
            tracker=tracker,
            scan_debounce_seconds=scan_debounce_seconds,
        )

    @property
    def capabilities(self) -> LightningStoreCapabilities:
        """Return the capabilities of the store."""
        return LightningStoreCapabilities(
            thread_safe=True,
            async_safe=True,
            zero_copy=True,
            otlp_traces=False,
        )

    async def close(self) -> None:
        """Close the store by closing the client pool."""
        await self._client_pool.close()

    @tracked("wait_for_rollouts")
    @healthcheck_before
    async def wait_for_rollouts(self, *, rollout_ids: List[str], timeout: Optional[float] = None) -> List[Rollout]:
        """Wait for the specified rollouts to finish, polling the database until done or timed out.

        All still-unfinished rollouts are fetched in a single query per polling round.

        Args:
            rollout_ids: Ids of the rollouts to wait for.
            timeout: Maximum number of seconds to wait. ``None`` waits indefinitely.

        Returns:
            The rollouts that finished before the deadline, ordered like ``rollout_ids``.
            Rollouts still unfinished when the deadline passes are omitted.
        """
        if not rollout_ids:
            # Nothing to wait for; skip the database round-trip entirely.
            return []

        # A monotonic clock keeps the timeout correct even if the wall clock is adjusted
        # (NTP sync, manual changes) while we are waiting.
        start_time = time.monotonic()
        current_time = start_time
        deadline = start_time + timeout if timeout is not None else None

        finished_rollouts: Dict[str, Rollout] = {}
        unfinished_rollout_ids = set(rollout_ids)

        while deadline is None or current_time <= deadline:
            async with self.collections.atomic(
                mode="r", snapshot=self._read_snapshot, labels=["rollouts"]
            ) as collections:
                # Query the rollouts that are not finished in a single query
                rollouts = await collections.rollouts.query(
                    filter={"rollout_id": {"within": list(unfinished_rollout_ids)}}
                )
                for rollout in rollouts.items:
                    if is_finished(rollout):
                        finished_rollouts[rollout.rollout_id] = rollout
                        # discard() tolerates the same rollout appearing twice in a result.
                        unfinished_rollout_ids.discard(rollout.rollout_id)

            if not unfinished_rollout_ids:
                break

            # Poll every 10 seconds by default.
            # Subtract 0.1s so there is still time for one more query before the deadline.
            if deadline is None:
                rest_time = 10.0
            else:
                rest_time = max(0.01, min(deadline - time.monotonic() - 0.1, 10.0))
            await asyncio.sleep(rest_time)
            current_time = time.monotonic()

            # Logging will help debugging when there are stuck rollouts.
            logger.debug(
                "Waiting for rollouts. Number of finished rollouts: %d; number of unfinished rollouts: %d",
                len(finished_rollouts),
                len(unfinished_rollout_ids),
            )
            if len(unfinished_rollout_ids) < 30:
                logger.debug("Unfinished rollouts: %s", unfinished_rollout_ids)

        # Reorder the rollouts to match the input order
        return [finished_rollouts[rollout_id] for rollout_id in rollout_ids if rollout_id in finished_rollouts]

    @tracked("_unlocked_many_rollouts_to_attempted_rollouts")
    async def _unlocked_many_rollouts_to_attempted_rollouts(
        self, collections: MongoLightningCollections, rollouts: Sequence[Rollout]
    ) -> List[Union[Rollout, AttemptedRollout]]:
        """Query the latest attempts for the rollouts, and attach them to the rollout objects.

        Args:
            collections: The collections to read attempts from.
            rollouts: The rollouts to enrich.

        Returns:
            One entry per input rollout, in input order. An entry is an ``AttemptedRollout``
            carrying the attempt with the highest ``sequence_id`` when the rollout has at least
            one attempt, and the original rollout unchanged otherwise.
        """
        if not rollouts:
            # No ids to look up; avoid an empty ``within`` query.
            return []

        rollout_ids = [rollout.rollout_id for rollout in rollouts]
        latest_attempts: Dict[str, Attempt] = {}
        # Use a distinct name so the ``collections`` parameter is not rebound.
        async with collections.atomic(
            mode="r", snapshot=self._read_snapshot, labels=["attempts"]
        ) as atomic_collections:
            attempts = await atomic_collections.attempts.query(
                filter={"rollout_id": {"within": rollout_ids}},
                sort={"name": "sequence_id", "order": "desc"},
            )
            for attempt in attempts.items:
                # Results are sorted by sequence_id descending, so the first attempt seen for a
                # rollout is its newest one; later (older) attempts are ignored.
                latest_attempts.setdefault(attempt.rollout_id, attempt)

        return [
            (
                AttemptedRollout(**rollout.model_dump(), attempt=latest_attempts[rollout.rollout_id])
                if rollout.rollout_id in latest_attempts
                else rollout
            )
            for rollout in rollouts
        ]
@@ -0,0 +1,3 @@
1
+ # Copyright (c) Microsoft. All rights reserved.
2
+
3
+ # TODO: Implement this