mantisdk 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mantisdk might be problematic. Click here for more details.

Files changed (190)
  1. mantisdk/__init__.py +22 -0
  2. mantisdk/adapter/__init__.py +15 -0
  3. mantisdk/adapter/base.py +94 -0
  4. mantisdk/adapter/messages.py +270 -0
  5. mantisdk/adapter/triplet.py +1028 -0
  6. mantisdk/algorithm/__init__.py +39 -0
  7. mantisdk/algorithm/apo/__init__.py +5 -0
  8. mantisdk/algorithm/apo/apo.py +889 -0
  9. mantisdk/algorithm/apo/prompts/apply_edit_variant01.poml +22 -0
  10. mantisdk/algorithm/apo/prompts/apply_edit_variant02.poml +18 -0
  11. mantisdk/algorithm/apo/prompts/text_gradient_variant01.poml +18 -0
  12. mantisdk/algorithm/apo/prompts/text_gradient_variant02.poml +16 -0
  13. mantisdk/algorithm/apo/prompts/text_gradient_variant03.poml +107 -0
  14. mantisdk/algorithm/base.py +162 -0
  15. mantisdk/algorithm/decorator.py +264 -0
  16. mantisdk/algorithm/fast.py +250 -0
  17. mantisdk/algorithm/gepa/__init__.py +59 -0
  18. mantisdk/algorithm/gepa/adapter.py +459 -0
  19. mantisdk/algorithm/gepa/gepa.py +364 -0
  20. mantisdk/algorithm/gepa/lib/__init__.py +18 -0
  21. mantisdk/algorithm/gepa/lib/adapters/README.md +12 -0
  22. mantisdk/algorithm/gepa/lib/adapters/__init__.py +0 -0
  23. mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/README.md +341 -0
  24. mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/__init__.py +1 -0
  25. mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/anymaths_adapter.py +174 -0
  26. mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/requirements.txt +1 -0
  27. mantisdk/algorithm/gepa/lib/adapters/default_adapter/README.md +0 -0
  28. mantisdk/algorithm/gepa/lib/adapters/default_adapter/__init__.py +0 -0
  29. mantisdk/algorithm/gepa/lib/adapters/default_adapter/default_adapter.py +209 -0
  30. mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/README.md +7 -0
  31. mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/__init__.py +0 -0
  32. mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/dspy_adapter.py +307 -0
  33. mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/README.md +99 -0
  34. mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/dspy_program_proposal_signature.py +137 -0
  35. mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/full_program_adapter.py +266 -0
  36. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/GEPA_RAG.md +621 -0
  37. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/__init__.py +56 -0
  38. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/evaluation_metrics.py +226 -0
  39. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/generic_rag_adapter.py +496 -0
  40. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/rag_pipeline.py +238 -0
  41. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_store_interface.py +212 -0
  42. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/__init__.py +2 -0
  43. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/chroma_store.py +196 -0
  44. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/lancedb_store.py +422 -0
  45. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/milvus_store.py +409 -0
  46. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/qdrant_store.py +368 -0
  47. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/weaviate_store.py +418 -0
  48. mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/README.md +552 -0
  49. mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/__init__.py +37 -0
  50. mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/mcp_adapter.py +705 -0
  51. mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/mcp_client.py +364 -0
  52. mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/README.md +9 -0
  53. mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/__init__.py +0 -0
  54. mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/terminal_bench_adapter.py +217 -0
  55. mantisdk/algorithm/gepa/lib/api.py +375 -0
  56. mantisdk/algorithm/gepa/lib/core/__init__.py +0 -0
  57. mantisdk/algorithm/gepa/lib/core/adapter.py +180 -0
  58. mantisdk/algorithm/gepa/lib/core/data_loader.py +74 -0
  59. mantisdk/algorithm/gepa/lib/core/engine.py +356 -0
  60. mantisdk/algorithm/gepa/lib/core/result.py +233 -0
  61. mantisdk/algorithm/gepa/lib/core/state.py +636 -0
  62. mantisdk/algorithm/gepa/lib/examples/__init__.py +0 -0
  63. mantisdk/algorithm/gepa/lib/examples/aime.py +24 -0
  64. mantisdk/algorithm/gepa/lib/examples/anymaths-bench/eval_default.py +111 -0
  65. mantisdk/algorithm/gepa/lib/examples/anymaths-bench/prompt-templates/instruction_prompt.txt +9 -0
  66. mantisdk/algorithm/gepa/lib/examples/anymaths-bench/prompt-templates/optimal_prompt.txt +24 -0
  67. mantisdk/algorithm/gepa/lib/examples/anymaths-bench/train_anymaths.py +177 -0
  68. mantisdk/algorithm/gepa/lib/examples/dspy_full_program_evolution/arc_agi.ipynb +25705 -0
  69. mantisdk/algorithm/gepa/lib/examples/dspy_full_program_evolution/example.ipynb +348 -0
  70. mantisdk/algorithm/gepa/lib/examples/mcp_adapter/__init__.py +4 -0
  71. mantisdk/algorithm/gepa/lib/examples/mcp_adapter/mcp_optimization_example.py +455 -0
  72. mantisdk/algorithm/gepa/lib/examples/rag_adapter/RAG_GUIDE.md +613 -0
  73. mantisdk/algorithm/gepa/lib/examples/rag_adapter/__init__.py +9 -0
  74. mantisdk/algorithm/gepa/lib/examples/rag_adapter/rag_optimization.py +824 -0
  75. mantisdk/algorithm/gepa/lib/examples/rag_adapter/requirements-rag.txt +29 -0
  76. mantisdk/algorithm/gepa/lib/examples/terminal-bench/prompt-templates/instruction_prompt.txt +16 -0
  77. mantisdk/algorithm/gepa/lib/examples/terminal-bench/prompt-templates/terminus.txt +9 -0
  78. mantisdk/algorithm/gepa/lib/examples/terminal-bench/train_terminus.py +161 -0
  79. mantisdk/algorithm/gepa/lib/gepa_utils.py +117 -0
  80. mantisdk/algorithm/gepa/lib/logging/__init__.py +0 -0
  81. mantisdk/algorithm/gepa/lib/logging/experiment_tracker.py +187 -0
  82. mantisdk/algorithm/gepa/lib/logging/logger.py +75 -0
  83. mantisdk/algorithm/gepa/lib/logging/utils.py +103 -0
  84. mantisdk/algorithm/gepa/lib/proposer/__init__.py +0 -0
  85. mantisdk/algorithm/gepa/lib/proposer/base.py +31 -0
  86. mantisdk/algorithm/gepa/lib/proposer/merge.py +357 -0
  87. mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/__init__.py +0 -0
  88. mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/base.py +49 -0
  89. mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/reflective_mutation.py +176 -0
  90. mantisdk/algorithm/gepa/lib/py.typed +0 -0
  91. mantisdk/algorithm/gepa/lib/strategies/__init__.py +0 -0
  92. mantisdk/algorithm/gepa/lib/strategies/batch_sampler.py +77 -0
  93. mantisdk/algorithm/gepa/lib/strategies/candidate_selector.py +50 -0
  94. mantisdk/algorithm/gepa/lib/strategies/component_selector.py +36 -0
  95. mantisdk/algorithm/gepa/lib/strategies/eval_policy.py +64 -0
  96. mantisdk/algorithm/gepa/lib/strategies/instruction_proposal.py +127 -0
  97. mantisdk/algorithm/gepa/lib/utils/__init__.py +10 -0
  98. mantisdk/algorithm/gepa/lib/utils/stop_condition.py +196 -0
  99. mantisdk/algorithm/gepa/tracing.py +105 -0
  100. mantisdk/algorithm/utils.py +177 -0
  101. mantisdk/algorithm/verl/__init__.py +5 -0
  102. mantisdk/algorithm/verl/interface.py +202 -0
  103. mantisdk/cli/__init__.py +56 -0
  104. mantisdk/cli/prometheus.py +115 -0
  105. mantisdk/cli/store.py +131 -0
  106. mantisdk/cli/vllm.py +29 -0
  107. mantisdk/client.py +408 -0
  108. mantisdk/config.py +348 -0
  109. mantisdk/emitter/__init__.py +43 -0
  110. mantisdk/emitter/annotation.py +370 -0
  111. mantisdk/emitter/exception.py +54 -0
  112. mantisdk/emitter/message.py +61 -0
  113. mantisdk/emitter/object.py +117 -0
  114. mantisdk/emitter/reward.py +320 -0
  115. mantisdk/env_var.py +156 -0
  116. mantisdk/execution/__init__.py +15 -0
  117. mantisdk/execution/base.py +64 -0
  118. mantisdk/execution/client_server.py +443 -0
  119. mantisdk/execution/events.py +69 -0
  120. mantisdk/execution/inter_process.py +16 -0
  121. mantisdk/execution/shared_memory.py +282 -0
  122. mantisdk/instrumentation/__init__.py +119 -0
  123. mantisdk/instrumentation/agentops.py +314 -0
  124. mantisdk/instrumentation/agentops_langchain.py +45 -0
  125. mantisdk/instrumentation/litellm.py +83 -0
  126. mantisdk/instrumentation/vllm.py +81 -0
  127. mantisdk/instrumentation/weave.py +500 -0
  128. mantisdk/litagent/__init__.py +11 -0
  129. mantisdk/litagent/decorator.py +536 -0
  130. mantisdk/litagent/litagent.py +252 -0
  131. mantisdk/llm_proxy.py +1890 -0
  132. mantisdk/logging.py +370 -0
  133. mantisdk/reward.py +7 -0
  134. mantisdk/runner/__init__.py +11 -0
  135. mantisdk/runner/agent.py +845 -0
  136. mantisdk/runner/base.py +182 -0
  137. mantisdk/runner/legacy.py +309 -0
  138. mantisdk/semconv.py +170 -0
  139. mantisdk/server.py +401 -0
  140. mantisdk/store/__init__.py +23 -0
  141. mantisdk/store/base.py +897 -0
  142. mantisdk/store/client_server.py +2092 -0
  143. mantisdk/store/collection/__init__.py +30 -0
  144. mantisdk/store/collection/base.py +587 -0
  145. mantisdk/store/collection/memory.py +970 -0
  146. mantisdk/store/collection/mongo.py +1412 -0
  147. mantisdk/store/collection_based.py +1823 -0
  148. mantisdk/store/insight.py +648 -0
  149. mantisdk/store/listener.py +58 -0
  150. mantisdk/store/memory.py +396 -0
  151. mantisdk/store/mongo.py +165 -0
  152. mantisdk/store/sqlite.py +3 -0
  153. mantisdk/store/threading.py +357 -0
  154. mantisdk/store/utils.py +142 -0
  155. mantisdk/tracer/__init__.py +16 -0
  156. mantisdk/tracer/agentops.py +242 -0
  157. mantisdk/tracer/base.py +287 -0
  158. mantisdk/tracer/dummy.py +106 -0
  159. mantisdk/tracer/otel.py +555 -0
  160. mantisdk/tracer/weave.py +677 -0
  161. mantisdk/trainer/__init__.py +6 -0
  162. mantisdk/trainer/init_utils.py +263 -0
  163. mantisdk/trainer/legacy.py +367 -0
  164. mantisdk/trainer/registry.py +12 -0
  165. mantisdk/trainer/trainer.py +618 -0
  166. mantisdk/types/__init__.py +6 -0
  167. mantisdk/types/core.py +553 -0
  168. mantisdk/types/resources.py +204 -0
  169. mantisdk/types/tracer.py +515 -0
  170. mantisdk/types/tracing.py +218 -0
  171. mantisdk/utils/__init__.py +1 -0
  172. mantisdk/utils/id.py +18 -0
  173. mantisdk/utils/metrics.py +1025 -0
  174. mantisdk/utils/otel.py +578 -0
  175. mantisdk/utils/otlp.py +536 -0
  176. mantisdk/utils/server_launcher.py +1045 -0
  177. mantisdk/utils/system_snapshot.py +81 -0
  178. mantisdk/verl/__init__.py +8 -0
  179. mantisdk/verl/__main__.py +6 -0
  180. mantisdk/verl/async_server.py +46 -0
  181. mantisdk/verl/config.yaml +27 -0
  182. mantisdk/verl/daemon.py +1154 -0
  183. mantisdk/verl/dataset.py +44 -0
  184. mantisdk/verl/entrypoint.py +248 -0
  185. mantisdk/verl/trainer.py +549 -0
  186. mantisdk-0.1.0.dist-info/METADATA +119 -0
  187. mantisdk-0.1.0.dist-info/RECORD +190 -0
  188. mantisdk-0.1.0.dist-info/WHEEL +4 -0
  189. mantisdk-0.1.0.dist-info/entry_points.txt +2 -0
  190. mantisdk-0.1.0.dist-info/licenses/LICENSE +19 -0
@@ -0,0 +1,250 @@
1
+ # Copyright (c) Microsoft. All rights reserved.
2
+
3
+ from __future__ import annotations
4
+
5
+ import asyncio
6
+ import logging
7
+ from datetime import datetime
8
+ from typing import TYPE_CHECKING, Any, List, Literal, Optional
9
+
10
+ from mantisdk.types import Attempt, Dataset, Rollout, RolloutStatus, Span
11
+
12
+ from .base import Algorithm
13
+ from .utils import with_llm_proxy, with_store
14
+
15
+ if TYPE_CHECKING:
16
+ from mantisdk.llm_proxy import LLMProxy
17
+ from mantisdk.store.base import LightningStore
18
+
19
# Module-scoped logger, named after this module's import path.
logger = logging.getLogger(__name__)

# Public names exported by a star-import of this module.
__all__ = [
    "FastAlgorithm",
    "Baseline",
]
22
+
23
+
24
class FastAlgorithm(Algorithm):
    """Marker base class for lightweight, developer-oriented algorithms.

    Subclasses favour quick feedback loops: they let an agent developer try
    small-scale experiments without sitting through long-running training
    jobs. The class adds no behaviour of its own on top of ``Algorithm``.
    """
31
+
32
+
33
def _timestamp_to_iso_str(timestamp: float) -> str:
    """Render a POSIX timestamp as an ISO-8601 string.

    No timezone is passed to ``datetime.fromtimestamp``, so the result is a
    naive datetime expressed in the machine's local time.
    """
    moment = datetime.fromtimestamp(timestamp)
    return moment.isoformat()
35
+
36
+
37
class Baseline(FastAlgorithm):
    """Reference implementation that streams the full dataset through the rollout queue.

    The baseline algorithm submits one rollout per dataset sample (throttled by
    the queue depth), waits for each rollout to finish, and logs every
    collected attempt and span. It is primarily useful as a smoke test for the
    platform plumbing rather than a performant trainer.

    The baseline algorithm will auto-start a LLM proxy if one is provided and
    not yet started (presumably handled by the ``with_llm_proxy`` decorator on
    ``run`` -- confirm against that decorator).

    Args:
        n_epochs: Number of passes over the train and val samples.
        train_split: Must be strictly between 0 and 1. NOTE(review): the value
            is validated and stored but never read by this class; ``run``
            takes the train/val datasets as separate arguments.
        polling_interval: Interval, in seconds, between store polls for queue
            depth and rollout completion.
        max_queue_length: A new rollout is only submitted while the number of
            rollouts in ``queuing``/``requeuing`` state is at most this value.
        span_verbosity: ``"keys"`` logs span attribute keys, ``"key_values"``
            logs the full attributes, ``"none"`` suppresses span logging.

    Raises:
        ValueError: If `train_split` falls outside the `(0, 1)` interval.

    Examples:
        ```python
        from mantisdk.algorithm.fast import Baseline

        algorithm = Baseline(n_epochs=2, train_split=0.8, span_verbosity="key_values")
        trainer.fit(algorithm, train_dataset=my_train, val_dataset=my_val)
        ```
    """

    # Rollout statuses after which no further polling is needed.
    _TERMINAL_STATUSES = ("succeeded", "failed", "cancelled")

    def __init__(
        self,
        *,
        n_epochs: int = 1,
        train_split: float = 0.5,
        polling_interval: float = 5.0,
        max_queue_length: int = 4,
        span_verbosity: Literal["keys", "key_values", "none"] = "keys",
    ) -> None:
        super().__init__()
        if not (0.0 < train_split < 1.0):
            raise ValueError("train_split must be between 0 and 1.")

        self.n_epochs = n_epochs
        self.train_split = train_split
        self.polling_interval = polling_interval
        self.max_queue_length = max_queue_length
        self.span_verbosity = span_verbosity

        # Running total across all harvest tasks; never reset between epochs or runs.
        self._finished_rollout_count = 0

    def _span_to_string(self, rollout_id: str, attempt: Attempt, span: Span) -> str:
        """Format a span for logging based on the configured verbosity.

        Returns an empty string when ``span_verbosity`` is ``"none"``.
        """
        if self.span_verbosity == "none":
            return ""

        prefix_msg = f"[Rollout {rollout_id} | Attempt {attempt.attempt_id} | Span {span.span_id}] #{span.sequence_id} ({span.name}) "
        # A missing (falsy) start or end time makes the duration unknowable.
        elapsed = f"{span.end_time - span.start_time:.2f}" if span.start_time and span.end_time else "unknown"

        msg = (
            prefix_msg
            + f"From {_timestamp_to_iso_str(span.start_time) if span.start_time else 'unknown'}, "
            + f"to {_timestamp_to_iso_str(span.end_time) if span.end_time else 'unknown'}, "
            + f"{elapsed} seconds. "
        )
        if self.span_verbosity == "key_values":
            msg += f"Attributes: {span.attributes}"
        else:
            msg += f"Attribute keys: {list(span.attributes.keys())}"
        return msg

    async def _handle_rollout_finish(self, rollout: Rollout) -> None:
        """Log attempt metadata and emit adapted traces when a rollout ends."""
        # Local import keeps this block self-contained; `time` is stdlib.
        import time

        store = self.get_store()

        rollout_id = rollout.rollout_id
        # Fall back to wall-clock "now" when the store has not recorded an end
        # time. The event-loop clock is monotonic with an arbitrary reference
        # point, so it must not be mixed with stored timestamps.
        # NOTE(review): assumes rollout.start_time is epoch seconds -- confirm.
        rollout_end_time = rollout.end_time or time.time()
        logger.info(
            f"[Rollout {rollout_id}] Finished with status {rollout.status} in {rollout_end_time - rollout.start_time:.2f} seconds."
        )

        # Logs all the attempts and their corresponding spans
        attempts = await store.query_attempts(rollout_id)
        for attempt in attempts:
            logger.info(
                "[Rollout %s | Attempt %s] ID: %s. Status: %s. Worker: %s",
                rollout_id,
                attempt.sequence_id,
                attempt.attempt_id,
                attempt.status,
                attempt.worker_id,
            )
            if self.span_verbosity == "none":
                # Nothing would be logged, so skip the store round trip.
                continue
            # Restrict the query to this attempt so spans are not repeated
            # (and mislabelled) under every attempt of the rollout.
            spans = await store.query_spans(rollout_id=rollout_id, attempt_id=attempt.attempt_id)
            for span in spans:
                logger.info(self._span_to_string(rollout_id, attempt, span))

        # Attempts to adapt the spans using the adapter if provided
        try:
            adapter = self.get_adapter()
        except ValueError:
            logger.warning("No adapter set for Baseline. Skipping trace adaptation.")
            adapter = None
        if adapter is not None:
            spans = await store.query_spans(rollout_id=rollout_id, attempt_id="latest")
            transformed_data = adapter.adapt(spans)
            logger.info(f"[Rollout {rollout_id}] Adapted data: {transformed_data}")

    async def _enqueue_when_capacity(
        self,
        store: LightningStore,
        sample: Any,
        mode: Literal["train", "val"],
        resources_id: Optional[str],
    ) -> Rollout:
        """Wait until the queue has room, then enqueue one rollout and return it.

        Polls the store every ``polling_interval`` seconds until at most
        ``max_queue_length`` rollouts are in ``queuing``/``requeuing`` state.
        """
        while True:
            queuing_rollouts = await store.query_rollouts(status_in=["queuing", "requeuing"])
            if len(queuing_rollouts) <= self.max_queue_length:
                return await store.enqueue_rollout(input=sample, mode=mode, resources_id=resources_id)
            # Sleep a bit and try again later.
            await asyncio.sleep(self.polling_interval)

    async def _enqueue_rollouts(
        self, dataset: Dataset[Any], train_indices: List[int], val_indices: List[int], resources_id: Optional[str]
    ) -> None:
        """Submit rollouts while respecting the maximum queue length.

        Every index in ``train_indices + val_indices`` is submitted exactly
        once; submission blocks while the queue is over capacity instead of
        skipping the sample.
        """
        store = self.get_store()
        train_index_set = set(train_indices)  # O(1) membership per sample

        for index in train_indices + val_indices:
            sample = dataset[index]
            mode: Literal["train", "val"] = "train" if index in train_index_set else "val"
            rollout = await self._enqueue_when_capacity(store, sample, mode, resources_id)
            logger.info(f"[Rollout {rollout.rollout_id}] Enqueued in {mode} mode with sample: {sample}")

    async def _harvest_rollout_spans(self, rollout_id: str) -> None:
        """Poll rollout status updates until completion and log transitions."""
        store = self.get_store()
        last_status: Optional[RolloutStatus] = None

        while True:
            rollout = await store.get_rollout_by_id(rollout_id)
            if rollout is None:
                # Keep polling; the store has no record for this ID yet.
                logger.debug(f"[Rollout {rollout_id}] Not found in the store yet.")
            elif rollout.status in self._TERMINAL_STATUSES:
                # Rollout is finished, log all the data.
                await self._handle_rollout_finish(rollout)
                # We are done here.
                self._finished_rollout_count += 1
                logger.info(f"Finished {self._finished_rollout_count} rollouts.")
                break
            elif last_status != rollout.status:
                if last_status is not None:
                    logger.info(f"[Rollout {rollout_id}] Status changed to {rollout.status}.")
                else:
                    logger.info(f"[Rollout {rollout_id}] Status is initialized to {rollout.status}.")
                last_status = rollout.status
            else:
                logger.debug(f"[Rollout {rollout_id}] Status is still {rollout.status}.")

            await asyncio.sleep(self.polling_interval)

    @with_llm_proxy()
    @with_store
    async def run(
        self,
        store: LightningStore,  # Injected by decorator - callers should not provide this parameter
        llm_proxy: Optional[LLMProxy],  # Injected by decorator - callers should not provide this parameter
        train_dataset: Optional[Dataset[Any]] = None,
        val_dataset: Optional[Dataset[Any]] = None,
    ) -> None:
        """Execute the baseline loop across the provided datasets.

        For each epoch, every train sample and then every val sample is
        enqueued as a rollout (throttled by ``max_queue_length``), a
        background task is started to watch each rollout, and the epoch ends
        once all of those tasks have completed. Returns early, after logging
        an error, when both datasets are missing or empty.
        """
        train_dataset_length = len(train_dataset) if train_dataset is not None else 0
        val_dataset_length = len(val_dataset) if val_dataset is not None else 0
        if train_dataset_length == 0 and val_dataset_length == 0:
            logger.error(
                "Baseline requires at least one dataset. Provide train_dataset or val_dataset before running."
            )
            return

        # Train samples first, then val samples; positions below
        # `train_dataset_length` are therefore train samples.
        concatenated_dataset: List[Any] = []
        if train_dataset is not None:
            concatenated_dataset.extend(train_dataset[i] for i in range(train_dataset_length))
        if val_dataset is not None:
            concatenated_dataset.extend(val_dataset[i] for i in range(val_dataset_length))
        train_indices = list(range(train_dataset_length))
        val_indices = list(range(train_dataset_length, train_dataset_length + val_dataset_length))
        logger.debug(f"Train indices: {train_indices}")
        logger.debug(f"Val indices: {val_indices}")

        # Currently we only support a single resource update at the start.
        initial_resources = self.get_initial_resources()
        resources_id: Optional[str]
        if initial_resources is not None:
            resource_update = await store.update_resources("default", initial_resources)
            resources_id = resource_update.resources_id
            logger.info(f"Initial resources set: {initial_resources}")
        else:
            logger.warning("No initial resources provided. Skip initializing resources.")
            resources_id = None

        for epoch in range(self.n_epochs):
            harvest_tasks: List[asyncio.Task[None]] = []
            logger.info(f"Proceeding epoch {epoch + 1}/{self.n_epochs}.")
            for index, sample in enumerate(concatenated_dataset):
                logger.info(
                    f"Processing index {index}. {len(train_indices)} train indices and {len(val_indices)} val indices in total."
                )
                mode: Literal["train", "val"] = "train" if index < train_dataset_length else "val"
                rollout = await self._enqueue_when_capacity(store, sample, mode, resources_id)
                harvest_tasks.append(asyncio.create_task(self._harvest_rollout_spans(rollout.rollout_id)))
                logger.info(f"Enqueued rollout {rollout.rollout_id} in {mode} mode with sample: {sample}")

            # Wait for all harvest tasks to complete
            logger.info(f"Waiting for {len(harvest_tasks)} harvest tasks to complete...")
            if harvest_tasks:
                await asyncio.gather(*harvest_tasks)
@@ -0,0 +1,59 @@
1
+ # Copyright (c) Microsoft. All rights reserved.
2
+
3
+ from .adapter import (
4
+ MantisdkGEPAAdapter,
5
+ MantisdkDataInst,
6
+ MantisdkTrajectory,
7
+ MantisdkRolloutOutput,
8
+ )
9
+ from .gepa import GEPA, TEMPLATE_AWARE_REFLECTION_PROMPT
10
+ from .tracing import GEPATracingContext
11
+
12
+ # Re-export the GEPAAdapter from the gepa library for convenience
13
+ from mantisdk.algorithm.gepa.lib.core.adapter import GEPAAdapter
14
+
15
+ # GEPA-specific call type decorators for tagging LLM calls
16
+ # Usage: @gepa.judge, @gepa.agent, @gepa.reflection
17
+ from mantisdk.types.tracing import call_type_decorator
18
+
19
agent = call_type_decorator("agent-call")
"""Decorator to tag LLM calls as agent calls.

Example:
    >>> @gepa.agent
    ... def run_agent(client, prompt):
    ...     return client.chat.completions.create(...)  # Tagged as "agent-call"
"""

judge = call_type_decorator("judge-call")
"""Decorator to tag LLM calls as judge/grading calls.

Example:
    >>> @gepa.judge
    ... def grade_response(client, response, expected):
    ...     return client.chat.completions.parse(...)  # Tagged as "judge-call"
"""

reflection = call_type_decorator("reflection-call")
"""Decorator to tag LLM calls as reflection/optimization calls.

Example:
    >>> @gepa.reflection
    ... def reflect_on_prompts(client, feedback):
    ...     return client.chat.completions.create(...)  # Tagged as "reflection-call"
"""

__all__ = [
    "GEPA",
    "GEPAAdapter",
    "MantisdkGEPAAdapter",
    "MantisdkDataInst",
    "MantisdkTrajectory",
    "MantisdkRolloutOutput",
    "GEPATracingContext",
    "TEMPLATE_AWARE_REFLECTION_PROMPT",
    # Call type decorators
    "agent",
    "judge",
    "reflection",
]