mantisdk 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mantisdk might be problematic. Click here for more details.
- mantisdk/__init__.py +22 -0
- mantisdk/adapter/__init__.py +15 -0
- mantisdk/adapter/base.py +94 -0
- mantisdk/adapter/messages.py +270 -0
- mantisdk/adapter/triplet.py +1028 -0
- mantisdk/algorithm/__init__.py +39 -0
- mantisdk/algorithm/apo/__init__.py +5 -0
- mantisdk/algorithm/apo/apo.py +889 -0
- mantisdk/algorithm/apo/prompts/apply_edit_variant01.poml +22 -0
- mantisdk/algorithm/apo/prompts/apply_edit_variant02.poml +18 -0
- mantisdk/algorithm/apo/prompts/text_gradient_variant01.poml +18 -0
- mantisdk/algorithm/apo/prompts/text_gradient_variant02.poml +16 -0
- mantisdk/algorithm/apo/prompts/text_gradient_variant03.poml +107 -0
- mantisdk/algorithm/base.py +162 -0
- mantisdk/algorithm/decorator.py +264 -0
- mantisdk/algorithm/fast.py +250 -0
- mantisdk/algorithm/gepa/__init__.py +59 -0
- mantisdk/algorithm/gepa/adapter.py +459 -0
- mantisdk/algorithm/gepa/gepa.py +364 -0
- mantisdk/algorithm/gepa/lib/__init__.py +18 -0
- mantisdk/algorithm/gepa/lib/adapters/README.md +12 -0
- mantisdk/algorithm/gepa/lib/adapters/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/README.md +341 -0
- mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/__init__.py +1 -0
- mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/anymaths_adapter.py +174 -0
- mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/requirements.txt +1 -0
- mantisdk/algorithm/gepa/lib/adapters/default_adapter/README.md +0 -0
- mantisdk/algorithm/gepa/lib/adapters/default_adapter/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/adapters/default_adapter/default_adapter.py +209 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/README.md +7 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/dspy_adapter.py +307 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/README.md +99 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/dspy_program_proposal_signature.py +137 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/full_program_adapter.py +266 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/GEPA_RAG.md +621 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/__init__.py +56 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/evaluation_metrics.py +226 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/generic_rag_adapter.py +496 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/rag_pipeline.py +238 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_store_interface.py +212 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/__init__.py +2 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/chroma_store.py +196 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/lancedb_store.py +422 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/milvus_store.py +409 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/qdrant_store.py +368 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/weaviate_store.py +418 -0
- mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/README.md +552 -0
- mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/__init__.py +37 -0
- mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/mcp_adapter.py +705 -0
- mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/mcp_client.py +364 -0
- mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/README.md +9 -0
- mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/terminal_bench_adapter.py +217 -0
- mantisdk/algorithm/gepa/lib/api.py +375 -0
- mantisdk/algorithm/gepa/lib/core/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/core/adapter.py +180 -0
- mantisdk/algorithm/gepa/lib/core/data_loader.py +74 -0
- mantisdk/algorithm/gepa/lib/core/engine.py +356 -0
- mantisdk/algorithm/gepa/lib/core/result.py +233 -0
- mantisdk/algorithm/gepa/lib/core/state.py +636 -0
- mantisdk/algorithm/gepa/lib/examples/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/examples/aime.py +24 -0
- mantisdk/algorithm/gepa/lib/examples/anymaths-bench/eval_default.py +111 -0
- mantisdk/algorithm/gepa/lib/examples/anymaths-bench/prompt-templates/instruction_prompt.txt +9 -0
- mantisdk/algorithm/gepa/lib/examples/anymaths-bench/prompt-templates/optimal_prompt.txt +24 -0
- mantisdk/algorithm/gepa/lib/examples/anymaths-bench/train_anymaths.py +177 -0
- mantisdk/algorithm/gepa/lib/examples/dspy_full_program_evolution/arc_agi.ipynb +25705 -0
- mantisdk/algorithm/gepa/lib/examples/dspy_full_program_evolution/example.ipynb +348 -0
- mantisdk/algorithm/gepa/lib/examples/mcp_adapter/__init__.py +4 -0
- mantisdk/algorithm/gepa/lib/examples/mcp_adapter/mcp_optimization_example.py +455 -0
- mantisdk/algorithm/gepa/lib/examples/rag_adapter/RAG_GUIDE.md +613 -0
- mantisdk/algorithm/gepa/lib/examples/rag_adapter/__init__.py +9 -0
- mantisdk/algorithm/gepa/lib/examples/rag_adapter/rag_optimization.py +824 -0
- mantisdk/algorithm/gepa/lib/examples/rag_adapter/requirements-rag.txt +29 -0
- mantisdk/algorithm/gepa/lib/examples/terminal-bench/prompt-templates/instruction_prompt.txt +16 -0
- mantisdk/algorithm/gepa/lib/examples/terminal-bench/prompt-templates/terminus.txt +9 -0
- mantisdk/algorithm/gepa/lib/examples/terminal-bench/train_terminus.py +161 -0
- mantisdk/algorithm/gepa/lib/gepa_utils.py +117 -0
- mantisdk/algorithm/gepa/lib/logging/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/logging/experiment_tracker.py +187 -0
- mantisdk/algorithm/gepa/lib/logging/logger.py +75 -0
- mantisdk/algorithm/gepa/lib/logging/utils.py +103 -0
- mantisdk/algorithm/gepa/lib/proposer/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/proposer/base.py +31 -0
- mantisdk/algorithm/gepa/lib/proposer/merge.py +357 -0
- mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/base.py +49 -0
- mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/reflective_mutation.py +176 -0
- mantisdk/algorithm/gepa/lib/py.typed +0 -0
- mantisdk/algorithm/gepa/lib/strategies/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/strategies/batch_sampler.py +77 -0
- mantisdk/algorithm/gepa/lib/strategies/candidate_selector.py +50 -0
- mantisdk/algorithm/gepa/lib/strategies/component_selector.py +36 -0
- mantisdk/algorithm/gepa/lib/strategies/eval_policy.py +64 -0
- mantisdk/algorithm/gepa/lib/strategies/instruction_proposal.py +127 -0
- mantisdk/algorithm/gepa/lib/utils/__init__.py +10 -0
- mantisdk/algorithm/gepa/lib/utils/stop_condition.py +196 -0
- mantisdk/algorithm/gepa/tracing.py +105 -0
- mantisdk/algorithm/utils.py +177 -0
- mantisdk/algorithm/verl/__init__.py +5 -0
- mantisdk/algorithm/verl/interface.py +202 -0
- mantisdk/cli/__init__.py +56 -0
- mantisdk/cli/prometheus.py +115 -0
- mantisdk/cli/store.py +131 -0
- mantisdk/cli/vllm.py +29 -0
- mantisdk/client.py +408 -0
- mantisdk/config.py +348 -0
- mantisdk/emitter/__init__.py +43 -0
- mantisdk/emitter/annotation.py +370 -0
- mantisdk/emitter/exception.py +54 -0
- mantisdk/emitter/message.py +61 -0
- mantisdk/emitter/object.py +117 -0
- mantisdk/emitter/reward.py +320 -0
- mantisdk/env_var.py +156 -0
- mantisdk/execution/__init__.py +15 -0
- mantisdk/execution/base.py +64 -0
- mantisdk/execution/client_server.py +443 -0
- mantisdk/execution/events.py +69 -0
- mantisdk/execution/inter_process.py +16 -0
- mantisdk/execution/shared_memory.py +282 -0
- mantisdk/instrumentation/__init__.py +119 -0
- mantisdk/instrumentation/agentops.py +314 -0
- mantisdk/instrumentation/agentops_langchain.py +45 -0
- mantisdk/instrumentation/litellm.py +83 -0
- mantisdk/instrumentation/vllm.py +81 -0
- mantisdk/instrumentation/weave.py +500 -0
- mantisdk/litagent/__init__.py +11 -0
- mantisdk/litagent/decorator.py +536 -0
- mantisdk/litagent/litagent.py +252 -0
- mantisdk/llm_proxy.py +1890 -0
- mantisdk/logging.py +370 -0
- mantisdk/reward.py +7 -0
- mantisdk/runner/__init__.py +11 -0
- mantisdk/runner/agent.py +845 -0
- mantisdk/runner/base.py +182 -0
- mantisdk/runner/legacy.py +309 -0
- mantisdk/semconv.py +170 -0
- mantisdk/server.py +401 -0
- mantisdk/store/__init__.py +23 -0
- mantisdk/store/base.py +897 -0
- mantisdk/store/client_server.py +2092 -0
- mantisdk/store/collection/__init__.py +30 -0
- mantisdk/store/collection/base.py +587 -0
- mantisdk/store/collection/memory.py +970 -0
- mantisdk/store/collection/mongo.py +1412 -0
- mantisdk/store/collection_based.py +1823 -0
- mantisdk/store/insight.py +648 -0
- mantisdk/store/listener.py +58 -0
- mantisdk/store/memory.py +396 -0
- mantisdk/store/mongo.py +165 -0
- mantisdk/store/sqlite.py +3 -0
- mantisdk/store/threading.py +357 -0
- mantisdk/store/utils.py +142 -0
- mantisdk/tracer/__init__.py +16 -0
- mantisdk/tracer/agentops.py +242 -0
- mantisdk/tracer/base.py +287 -0
- mantisdk/tracer/dummy.py +106 -0
- mantisdk/tracer/otel.py +555 -0
- mantisdk/tracer/weave.py +677 -0
- mantisdk/trainer/__init__.py +6 -0
- mantisdk/trainer/init_utils.py +263 -0
- mantisdk/trainer/legacy.py +367 -0
- mantisdk/trainer/registry.py +12 -0
- mantisdk/trainer/trainer.py +618 -0
- mantisdk/types/__init__.py +6 -0
- mantisdk/types/core.py +553 -0
- mantisdk/types/resources.py +204 -0
- mantisdk/types/tracer.py +515 -0
- mantisdk/types/tracing.py +218 -0
- mantisdk/utils/__init__.py +1 -0
- mantisdk/utils/id.py +18 -0
- mantisdk/utils/metrics.py +1025 -0
- mantisdk/utils/otel.py +578 -0
- mantisdk/utils/otlp.py +536 -0
- mantisdk/utils/server_launcher.py +1045 -0
- mantisdk/utils/system_snapshot.py +81 -0
- mantisdk/verl/__init__.py +8 -0
- mantisdk/verl/__main__.py +6 -0
- mantisdk/verl/async_server.py +46 -0
- mantisdk/verl/config.yaml +27 -0
- mantisdk/verl/daemon.py +1154 -0
- mantisdk/verl/dataset.py +44 -0
- mantisdk/verl/entrypoint.py +248 -0
- mantisdk/verl/trainer.py +549 -0
- mantisdk-0.1.0.dist-info/METADATA +119 -0
- mantisdk-0.1.0.dist-info/RECORD +190 -0
- mantisdk-0.1.0.dist-info/WHEEL +4 -0
- mantisdk-0.1.0.dist-info/entry_points.txt +2 -0
- mantisdk-0.1.0.dist-info/licenses/LICENSE +19 -0
|
@@ -0,0 +1,250 @@
|
|
|
1
|
+
# Copyright (c) Microsoft. All rights reserved.
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import asyncio
|
|
6
|
+
import logging
|
|
7
|
+
from datetime import datetime
|
|
8
|
+
from typing import TYPE_CHECKING, Any, List, Literal, Optional
|
|
9
|
+
|
|
10
|
+
from mantisdk.types import Attempt, Dataset, Rollout, RolloutStatus, Span
|
|
11
|
+
|
|
12
|
+
from .base import Algorithm
|
|
13
|
+
from .utils import with_llm_proxy, with_store
|
|
14
|
+
|
|
15
|
+
if TYPE_CHECKING:
|
|
16
|
+
from mantisdk.llm_proxy import LLMProxy
|
|
17
|
+
from mantisdk.store.base import LightningStore
|
|
18
|
+
|
|
19
|
+
# Module-level logger, named after this module's dotted import path.
logger = logging.getLogger(__name__)

# Public names exported by `from ... import *`.
__all__ = ["FastAlgorithm", "Baseline"]
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class FastAlgorithm(Algorithm):
    """Common parent for lightweight, development-oriented algorithms.

    Algorithms deriving from this class favour a quick turnaround. They are
    meant for small experiments that give an agent developer feedback fast,
    rather than for long-running training jobs.

    This class adds no behaviour of its own on top of ``Algorithm``. It exists
    to group fast algorithms under a shared base type.
    """
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def _timestamp_to_iso_str(timestamp: float) -> str:
    """Render a POSIX timestamp (seconds) as an ISO-8601 string.

    ``datetime.fromtimestamp`` is called without a ``tz`` argument. The result
    is therefore a naive local-time value without a UTC-offset suffix.
    """
    local_moment = datetime.fromtimestamp(timestamp)
    return local_moment.isoformat()
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class Baseline(FastAlgorithm):
    """Reference implementation that streams the full dataset through the rollout queue.

    The baseline algorithm submits every sample as a rollout, throttling on the
    store's queue depth. It waits for each rollout to finish, then logs the
    attempts, their spans, and the adapter output (when an adapter is
    configured). It is primarily useful as a smoke test for the platform
    plumbing rather than a performant trainer.

    The baseline algorithm will auto-start an LLM proxy if one is provided and
    not yet started (presumably via the ``with_llm_proxy`` decorator on ``run``).

    Args:
        n_epochs: Number of passes over the combined train + val samples.
        train_split: Must be strictly between 0 and 1. NOTE: this value is
            currently only validated. ``run`` uses ``train_dataset`` and
            ``val_dataset`` exactly as given and does not re-split them.
        polling_interval: Interval, in seconds, to poll the store for queue
            depth and rollout completion.
        max_queue_length: A new rollout is enqueued only while at most this many
            rollouts are in the ``"queuing"``/``"requeuing"`` state.
        span_verbosity: ``"none"`` disables span logging, ``"keys"`` logs only
            attribute keys, and ``"key_values"`` logs full span attributes.

    Raises:
        ValueError: If `train_split` falls outside the `(0, 1)` interval.

    Examples:
        ```python
        from mantisdk.algorithm.fast import Baseline

        algorithm = Baseline(n_epochs=2, train_split=0.8, span_verbosity="key_values")
        trainer.fit(algorithm, train_dataset=my_train, val_dataset=my_val)
        ```
    """

    def __init__(
        self,
        *,
        n_epochs: int = 1,
        train_split: float = 0.5,
        polling_interval: float = 5.0,
        max_queue_length: int = 4,
        span_verbosity: Literal["keys", "key_values", "none"] = "keys",
    ) -> None:
        super().__init__()
        self.n_epochs = n_epochs
        self.train_split = train_split
        self.polling_interval = polling_interval
        self.max_queue_length = max_queue_length
        self.span_verbosity = span_verbosity
        if not (0.0 < self.train_split < 1.0):
            raise ValueError("train_split must be between 0 and 1.")

        # Running total across all epochs; only used for progress logging.
        self._finished_rollout_count = 0

    def _span_to_string(self, rollout_id: str, attempt: Attempt, span: Span) -> str:
        """Format a span for logging according to ``span_verbosity``.

        Returns an empty string when verbosity is ``"none"``. Otherwise the
        message contains:

        - the rollout, attempt and span identifiers;
        - the start and end times and the elapsed seconds;
        - the full attribute mapping (``"key_values"``) or only its keys
          (``"keys"``).
        """
        if self.span_verbosity == "none":
            return ""

        # Compare against None explicitly so a 0.0 timestamp is not mistaken
        # for a missing one.
        has_start = span.start_time is not None
        has_end = span.end_time is not None

        prefix_msg = f"[Rollout {rollout_id} | Attempt {attempt.attempt_id} | Span {span.span_id}] #{span.sequence_id} ({span.name}) "
        start_str = _timestamp_to_iso_str(span.start_time) if has_start else "unknown"
        end_str = _timestamp_to_iso_str(span.end_time) if has_end else "unknown"
        elapsed = f"{span.end_time - span.start_time:.2f}" if has_start and has_end else "unknown"

        msg = prefix_msg + f"From {start_str}, " + f"to {end_str}, " + f"{elapsed} seconds. "
        if self.span_verbosity == "key_values":
            msg += f"Attributes: {span.attributes}"
        else:
            msg += f"Attribute keys: {list(span.attributes.keys())}"
        return msg

    async def _handle_rollout_finish(self, rollout: Rollout) -> None:
        """Log a finished rollout, its attempts and spans, and the adapted trace.

        For every attempt of ``rollout``, the attempt metadata is logged, followed
        by the spans recorded for that specific attempt (skipped when
        ``span_verbosity`` is ``"none"``). If an adapter is configured, the spans
        of the latest attempt are passed through ``adapter.adapt`` and the result
        is logged.
        """
        import time

        store = self.get_store()
        rollout_id = rollout.rollout_id

        # Fall back to the wall clock. The event loop's ``time()`` is a monotonic
        # clock with an arbitrary reference point, so it cannot be subtracted
        # from a wall-clock start time.
        # NOTE(review): assumes rollout.start_time is POSIX seconds, like span timestamps.
        rollout_end_time = rollout.end_time if rollout.end_time is not None else time.time()
        logger.info(
            f"[Rollout {rollout_id}] Finished with status {rollout.status} in {rollout_end_time - rollout.start_time:.2f} seconds."
        )

        # Log all the attempts, each followed by the spans that belong to it.
        attempts = await store.query_attempts(rollout_id)
        for attempt in attempts:
            logger.info(
                "[Rollout %s | Attempt %s] ID: %s. Status: %s. Worker: %s",
                rollout_id,
                attempt.sequence_id,
                attempt.attempt_id,
                attempt.status,
                attempt.worker_id,
            )
            if self.span_verbosity == "none":
                continue
            # Scope the query to this attempt. Otherwise every attempt would log
            # every span of the rollout under the wrong attempt id.
            spans = await store.query_spans(rollout_id=rollout_id, attempt_id=attempt.attempt_id)
            for span in spans:
                logger.info(self._span_to_string(rollout_id, attempt, span))

        # Adapt the latest attempt's spans if an adapter is configured.
        try:
            adapter = self.get_adapter()
        except ValueError:
            logger.warning("No adapter set for Baseline. Skipping trace adaptation.")
            return
        if adapter is None:
            return
        latest_spans = await store.query_spans(rollout_id=rollout_id, attempt_id="latest")
        transformed_data = adapter.adapt(latest_spans)
        logger.info(f"[Rollout {rollout_id}] Adapted data: {transformed_data}")

    async def _wait_for_queue_capacity(self, store: LightningStore) -> None:
        """Poll ``store`` until at most ``max_queue_length`` rollouts are waiting."""
        while True:
            queuing_rollouts = await store.query_rollouts(status_in=["queuing", "requeuing"])
            if len(queuing_rollouts) <= self.max_queue_length:
                return
            # Queue is full; sleep a bit and try again later.
            await asyncio.sleep(self.polling_interval)

    async def _enqueue_rollouts(
        self,
        dataset: Dataset[Any],
        train_indices: List[int],
        val_indices: List[int],
        resources_id: Optional[str],
    ) -> None:
        """Submit one rollout per index while respecting ``max_queue_length``.

        ``train_indices`` and ``val_indices`` index into ``dataset``. Rollouts
        for the former are enqueued in ``"train"`` mode and those for the latter
        in ``"val"`` mode. Before each submission the method waits until the
        queue has room, so no sample is skipped when the queue is temporarily
        full.
        """
        store = self.get_store()
        # Set membership keeps the per-sample mode lookup O(1).
        train_index_set = set(train_indices)

        for index in train_indices + val_indices:
            await self._wait_for_queue_capacity(store)
            sample = dataset[index]
            mode = "train" if index in train_index_set else "val"
            rollout = await store.enqueue_rollout(input=sample, mode=mode, resources_id=resources_id)
            logger.info(f"[Rollout {rollout.rollout_id}] Enqueued in {mode} mode with sample: {sample}")

    async def _harvest_rollout_spans(self, rollout_id: str) -> None:
        """Poll a rollout until it reaches a terminal status, logging transitions.

        Terminal statuses are ``"succeeded"``, ``"failed"`` and ``"cancelled"``.
        Once one is reached, ``_handle_rollout_finish`` runs and the
        finished-rollout counter is incremented.
        """
        store = self.get_store()
        last_status: Optional[RolloutStatus] = None

        while True:
            rollout = await store.get_rollout_by_id(rollout_id)
            if rollout is not None:
                if rollout.status in ["succeeded", "failed", "cancelled"]:
                    # Rollout is finished: log all the data and stop polling.
                    await self._handle_rollout_finish(rollout)
                    self._finished_rollout_count += 1
                    logger.info(f"Finished {self._finished_rollout_count} rollouts.")
                    break

                if last_status != rollout.status:
                    if last_status is not None:
                        logger.info(f"[Rollout {rollout_id}] Status changed to {rollout.status}.")
                    else:
                        logger.info(f"[Rollout {rollout_id}] Status is initialized to {rollout.status}.")
                    last_status = rollout.status
                else:
                    logger.debug(f"[Rollout {rollout_id}] Status is still {rollout.status}.")

            await asyncio.sleep(self.polling_interval)

    @with_llm_proxy()
    @with_store
    async def run(
        self,
        store: LightningStore,  # Injected by decorator - callers should not provide this parameter
        llm_proxy: Optional[LLMProxy],  # Injected by decorator - callers should not provide this parameter
        train_dataset: Optional[Dataset[Any]] = None,
        val_dataset: Optional[Dataset[Any]] = None,
    ) -> None:
        """Execute the baseline loop across the provided datasets.

        Train samples come first, followed by val samples. For each epoch every
        sample is enqueued once, throttled by ``max_queue_length``. A harvest task
        is started per rollout, and the epoch ends when all of its harvest tasks
        complete. Logs an error and returns early when both datasets are missing
        or empty.
        """
        train_dataset_length = len(train_dataset) if train_dataset is not None else 0
        val_dataset_length = len(val_dataset) if val_dataset is not None else 0
        if train_dataset_length == 0 and val_dataset_length == 0:
            logger.error(
                "Baseline requires at least one dataset. Provide train_dataset or val_dataset before running."
            )
            return

        # Materialize train samples followed by val samples into one indexable list.
        concatenated_dataset: List[Any] = []
        for dataset in (train_dataset, val_dataset):
            if dataset is not None:
                concatenated_dataset.extend(dataset[i] for i in range(len(dataset)))
        train_indices = list(range(0, train_dataset_length))
        val_indices = list(range(train_dataset_length, train_dataset_length + val_dataset_length))
        logger.debug(f"Train indices: {train_indices}")
        logger.debug(f"Val indices: {val_indices}")

        # Currently only a single resource update, at the start, is supported.
        initial_resources = self.get_initial_resources()
        if initial_resources is not None:
            resource_update = await store.update_resources("default", initial_resources)
            resources_id: Optional[str] = resource_update.resources_id
            logger.info(f"Initial resources set: {initial_resources}")
        else:
            logger.warning("No initial resources provided. Skip initializing resources.")
            resources_id = None

        for epoch in range(self.n_epochs):
            harvest_tasks: List[asyncio.Task[None]] = []
            logger.info(f"Proceeding epoch {epoch + 1}/{self.n_epochs}.")
            for index in train_indices + val_indices:
                logger.info(
                    f"Processing index {index}. {len(train_indices)} train indices and {len(val_indices)} val indices in total."
                )
                # Only enqueue when at most `max_queue_length` rollouts are waiting.
                await self._wait_for_queue_capacity(store)
                sample = concatenated_dataset[index]
                # Train indices occupy [0, train_dataset_length), so a range check replaces the O(n) list scan.
                mode = "train" if index < train_dataset_length else "val"
                rollout = await store.enqueue_rollout(input=sample, mode=mode, resources_id=resources_id)
                harvest_tasks.append(asyncio.create_task(self._harvest_rollout_spans(rollout.rollout_id)))
                logger.info(f"Enqueued rollout {rollout.rollout_id} in {mode} mode with sample: {sample}")

            # Wait for all harvest tasks of this epoch to complete.
            logger.info(f"Waiting for {len(harvest_tasks)} harvest tasks to complete...")
            if harvest_tasks:
                try:
                    await asyncio.gather(*harvest_tasks)
                except Exception:
                    # gather() does not cancel the remaining awaitables when one fails.
                    # Cancel them so they stop polling the store in the background.
                    for task in harvest_tasks:
                        task.cancel()
                    raise
|
|
@@ -0,0 +1,59 @@
|
|
|
1
|
+
# Copyright (c) Microsoft. All rights reserved.
|
|
2
|
+
|
|
3
|
+
from .adapter import (
|
|
4
|
+
MantisdkGEPAAdapter,
|
|
5
|
+
MantisdkDataInst,
|
|
6
|
+
MantisdkTrajectory,
|
|
7
|
+
MantisdkRolloutOutput,
|
|
8
|
+
)
|
|
9
|
+
from .gepa import GEPA, TEMPLATE_AWARE_REFLECTION_PROMPT
|
|
10
|
+
from .tracing import GEPATracingContext
|
|
11
|
+
|
|
12
|
+
# Re-export the GEPAAdapter from the gepa library for convenience
|
|
13
|
+
from mantisdk.algorithm.gepa.lib.core.adapter import GEPAAdapter
|
|
14
|
+
|
|
15
|
+
# GEPA-specific call type decorators for tagging LLM calls
|
|
16
|
+
# Usage: @gepa.judge, @gepa.agent, @gepa.reflection
|
|
17
|
+
from mantisdk.types.tracing import call_type_decorator
|
|
18
|
+
|
|
19
|
+
agent = call_type_decorator("agent-call")
"""Decorator to tag LLM calls as agent calls.

Example:
    >>> @gepa.agent
    ... def run_agent(client, prompt):
    ...     return client.chat.completions.create(...)  # Tagged as "agent-call"
"""

judge = call_type_decorator("judge-call")
"""Decorator to tag LLM calls as judge/grading calls.

Example:
    >>> @gepa.judge
    ... def grade_response(client, response, expected):
    ...     return client.chat.completions.parse(...)  # Tagged as "judge-call"
"""

reflection = call_type_decorator("reflection-call")
"""Decorator to tag LLM calls as reflection/optimization calls.

Example:
    >>> @gepa.reflection
    ... def reflect_on_prompts(client, feedback):
    ...     return client.chat.completions.create(...)  # Tagged as "reflection-call"
"""

__all__ = [
    "GEPA",
    "GEPAAdapter",
    "MantisdkGEPAAdapter",
    "MantisdkDataInst",
    "MantisdkTrajectory",
    "MantisdkRolloutOutput",
    "GEPATracingContext",
    "TEMPLATE_AWARE_REFLECTION_PROMPT",
    # Call type decorators
    "agent",
    "judge",
    "reflection",
]
|