mantisdk-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of mantisdk might be problematic.

Files changed (190)
  1. mantisdk/__init__.py +22 -0
  2. mantisdk/adapter/__init__.py +15 -0
  3. mantisdk/adapter/base.py +94 -0
  4. mantisdk/adapter/messages.py +270 -0
  5. mantisdk/adapter/triplet.py +1028 -0
  6. mantisdk/algorithm/__init__.py +39 -0
  7. mantisdk/algorithm/apo/__init__.py +5 -0
  8. mantisdk/algorithm/apo/apo.py +889 -0
  9. mantisdk/algorithm/apo/prompts/apply_edit_variant01.poml +22 -0
  10. mantisdk/algorithm/apo/prompts/apply_edit_variant02.poml +18 -0
  11. mantisdk/algorithm/apo/prompts/text_gradient_variant01.poml +18 -0
  12. mantisdk/algorithm/apo/prompts/text_gradient_variant02.poml +16 -0
  13. mantisdk/algorithm/apo/prompts/text_gradient_variant03.poml +107 -0
  14. mantisdk/algorithm/base.py +162 -0
  15. mantisdk/algorithm/decorator.py +264 -0
  16. mantisdk/algorithm/fast.py +250 -0
  17. mantisdk/algorithm/gepa/__init__.py +59 -0
  18. mantisdk/algorithm/gepa/adapter.py +459 -0
  19. mantisdk/algorithm/gepa/gepa.py +364 -0
  20. mantisdk/algorithm/gepa/lib/__init__.py +18 -0
  21. mantisdk/algorithm/gepa/lib/adapters/README.md +12 -0
  22. mantisdk/algorithm/gepa/lib/adapters/__init__.py +0 -0
  23. mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/README.md +341 -0
  24. mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/__init__.py +1 -0
  25. mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/anymaths_adapter.py +174 -0
  26. mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/requirements.txt +1 -0
  27. mantisdk/algorithm/gepa/lib/adapters/default_adapter/README.md +0 -0
  28. mantisdk/algorithm/gepa/lib/adapters/default_adapter/__init__.py +0 -0
  29. mantisdk/algorithm/gepa/lib/adapters/default_adapter/default_adapter.py +209 -0
  30. mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/README.md +7 -0
  31. mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/__init__.py +0 -0
  32. mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/dspy_adapter.py +307 -0
  33. mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/README.md +99 -0
  34. mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/dspy_program_proposal_signature.py +137 -0
  35. mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/full_program_adapter.py +266 -0
  36. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/GEPA_RAG.md +621 -0
  37. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/__init__.py +56 -0
  38. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/evaluation_metrics.py +226 -0
  39. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/generic_rag_adapter.py +496 -0
  40. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/rag_pipeline.py +238 -0
  41. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_store_interface.py +212 -0
  42. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/__init__.py +2 -0
  43. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/chroma_store.py +196 -0
  44. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/lancedb_store.py +422 -0
  45. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/milvus_store.py +409 -0
  46. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/qdrant_store.py +368 -0
  47. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/weaviate_store.py +418 -0
  48. mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/README.md +552 -0
  49. mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/__init__.py +37 -0
  50. mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/mcp_adapter.py +705 -0
  51. mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/mcp_client.py +364 -0
  52. mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/README.md +9 -0
  53. mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/__init__.py +0 -0
  54. mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/terminal_bench_adapter.py +217 -0
  55. mantisdk/algorithm/gepa/lib/api.py +375 -0
  56. mantisdk/algorithm/gepa/lib/core/__init__.py +0 -0
  57. mantisdk/algorithm/gepa/lib/core/adapter.py +180 -0
  58. mantisdk/algorithm/gepa/lib/core/data_loader.py +74 -0
  59. mantisdk/algorithm/gepa/lib/core/engine.py +356 -0
  60. mantisdk/algorithm/gepa/lib/core/result.py +233 -0
  61. mantisdk/algorithm/gepa/lib/core/state.py +636 -0
  62. mantisdk/algorithm/gepa/lib/examples/__init__.py +0 -0
  63. mantisdk/algorithm/gepa/lib/examples/aime.py +24 -0
  64. mantisdk/algorithm/gepa/lib/examples/anymaths-bench/eval_default.py +111 -0
  65. mantisdk/algorithm/gepa/lib/examples/anymaths-bench/prompt-templates/instruction_prompt.txt +9 -0
  66. mantisdk/algorithm/gepa/lib/examples/anymaths-bench/prompt-templates/optimal_prompt.txt +24 -0
  67. mantisdk/algorithm/gepa/lib/examples/anymaths-bench/train_anymaths.py +177 -0
  68. mantisdk/algorithm/gepa/lib/examples/dspy_full_program_evolution/arc_agi.ipynb +25705 -0
  69. mantisdk/algorithm/gepa/lib/examples/dspy_full_program_evolution/example.ipynb +348 -0
  70. mantisdk/algorithm/gepa/lib/examples/mcp_adapter/__init__.py +4 -0
  71. mantisdk/algorithm/gepa/lib/examples/mcp_adapter/mcp_optimization_example.py +455 -0
  72. mantisdk/algorithm/gepa/lib/examples/rag_adapter/RAG_GUIDE.md +613 -0
  73. mantisdk/algorithm/gepa/lib/examples/rag_adapter/__init__.py +9 -0
  74. mantisdk/algorithm/gepa/lib/examples/rag_adapter/rag_optimization.py +824 -0
  75. mantisdk/algorithm/gepa/lib/examples/rag_adapter/requirements-rag.txt +29 -0
  76. mantisdk/algorithm/gepa/lib/examples/terminal-bench/prompt-templates/instruction_prompt.txt +16 -0
  77. mantisdk/algorithm/gepa/lib/examples/terminal-bench/prompt-templates/terminus.txt +9 -0
  78. mantisdk/algorithm/gepa/lib/examples/terminal-bench/train_terminus.py +161 -0
  79. mantisdk/algorithm/gepa/lib/gepa_utils.py +117 -0
  80. mantisdk/algorithm/gepa/lib/logging/__init__.py +0 -0
  81. mantisdk/algorithm/gepa/lib/logging/experiment_tracker.py +187 -0
  82. mantisdk/algorithm/gepa/lib/logging/logger.py +75 -0
  83. mantisdk/algorithm/gepa/lib/logging/utils.py +103 -0
  84. mantisdk/algorithm/gepa/lib/proposer/__init__.py +0 -0
  85. mantisdk/algorithm/gepa/lib/proposer/base.py +31 -0
  86. mantisdk/algorithm/gepa/lib/proposer/merge.py +357 -0
  87. mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/__init__.py +0 -0
  88. mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/base.py +49 -0
  89. mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/reflective_mutation.py +176 -0
  90. mantisdk/algorithm/gepa/lib/py.typed +0 -0
  91. mantisdk/algorithm/gepa/lib/strategies/__init__.py +0 -0
  92. mantisdk/algorithm/gepa/lib/strategies/batch_sampler.py +77 -0
  93. mantisdk/algorithm/gepa/lib/strategies/candidate_selector.py +50 -0
  94. mantisdk/algorithm/gepa/lib/strategies/component_selector.py +36 -0
  95. mantisdk/algorithm/gepa/lib/strategies/eval_policy.py +64 -0
  96. mantisdk/algorithm/gepa/lib/strategies/instruction_proposal.py +127 -0
  97. mantisdk/algorithm/gepa/lib/utils/__init__.py +10 -0
  98. mantisdk/algorithm/gepa/lib/utils/stop_condition.py +196 -0
  99. mantisdk/algorithm/gepa/tracing.py +105 -0
  100. mantisdk/algorithm/utils.py +177 -0
  101. mantisdk/algorithm/verl/__init__.py +5 -0
  102. mantisdk/algorithm/verl/interface.py +202 -0
  103. mantisdk/cli/__init__.py +56 -0
  104. mantisdk/cli/prometheus.py +115 -0
  105. mantisdk/cli/store.py +131 -0
  106. mantisdk/cli/vllm.py +29 -0
  107. mantisdk/client.py +408 -0
  108. mantisdk/config.py +348 -0
  109. mantisdk/emitter/__init__.py +43 -0
  110. mantisdk/emitter/annotation.py +370 -0
  111. mantisdk/emitter/exception.py +54 -0
  112. mantisdk/emitter/message.py +61 -0
  113. mantisdk/emitter/object.py +117 -0
  114. mantisdk/emitter/reward.py +320 -0
  115. mantisdk/env_var.py +156 -0
  116. mantisdk/execution/__init__.py +15 -0
  117. mantisdk/execution/base.py +64 -0
  118. mantisdk/execution/client_server.py +443 -0
  119. mantisdk/execution/events.py +69 -0
  120. mantisdk/execution/inter_process.py +16 -0
  121. mantisdk/execution/shared_memory.py +282 -0
  122. mantisdk/instrumentation/__init__.py +119 -0
  123. mantisdk/instrumentation/agentops.py +314 -0
  124. mantisdk/instrumentation/agentops_langchain.py +45 -0
  125. mantisdk/instrumentation/litellm.py +83 -0
  126. mantisdk/instrumentation/vllm.py +81 -0
  127. mantisdk/instrumentation/weave.py +500 -0
  128. mantisdk/litagent/__init__.py +11 -0
  129. mantisdk/litagent/decorator.py +536 -0
  130. mantisdk/litagent/litagent.py +252 -0
  131. mantisdk/llm_proxy.py +1890 -0
  132. mantisdk/logging.py +370 -0
  133. mantisdk/reward.py +7 -0
  134. mantisdk/runner/__init__.py +11 -0
  135. mantisdk/runner/agent.py +845 -0
  136. mantisdk/runner/base.py +182 -0
  137. mantisdk/runner/legacy.py +309 -0
  138. mantisdk/semconv.py +170 -0
  139. mantisdk/server.py +401 -0
  140. mantisdk/store/__init__.py +23 -0
  141. mantisdk/store/base.py +897 -0
  142. mantisdk/store/client_server.py +2092 -0
  143. mantisdk/store/collection/__init__.py +30 -0
  144. mantisdk/store/collection/base.py +587 -0
  145. mantisdk/store/collection/memory.py +970 -0
  146. mantisdk/store/collection/mongo.py +1412 -0
  147. mantisdk/store/collection_based.py +1823 -0
  148. mantisdk/store/insight.py +648 -0
  149. mantisdk/store/listener.py +58 -0
  150. mantisdk/store/memory.py +396 -0
  151. mantisdk/store/mongo.py +165 -0
  152. mantisdk/store/sqlite.py +3 -0
  153. mantisdk/store/threading.py +357 -0
  154. mantisdk/store/utils.py +142 -0
  155. mantisdk/tracer/__init__.py +16 -0
  156. mantisdk/tracer/agentops.py +242 -0
  157. mantisdk/tracer/base.py +287 -0
  158. mantisdk/tracer/dummy.py +106 -0
  159. mantisdk/tracer/otel.py +555 -0
  160. mantisdk/tracer/weave.py +677 -0
  161. mantisdk/trainer/__init__.py +6 -0
  162. mantisdk/trainer/init_utils.py +263 -0
  163. mantisdk/trainer/legacy.py +367 -0
  164. mantisdk/trainer/registry.py +12 -0
  165. mantisdk/trainer/trainer.py +618 -0
  166. mantisdk/types/__init__.py +6 -0
  167. mantisdk/types/core.py +553 -0
  168. mantisdk/types/resources.py +204 -0
  169. mantisdk/types/tracer.py +515 -0
  170. mantisdk/types/tracing.py +218 -0
  171. mantisdk/utils/__init__.py +1 -0
  172. mantisdk/utils/id.py +18 -0
  173. mantisdk/utils/metrics.py +1025 -0
  174. mantisdk/utils/otel.py +578 -0
  175. mantisdk/utils/otlp.py +536 -0
  176. mantisdk/utils/server_launcher.py +1045 -0
  177. mantisdk/utils/system_snapshot.py +81 -0
  178. mantisdk/verl/__init__.py +8 -0
  179. mantisdk/verl/__main__.py +6 -0
  180. mantisdk/verl/async_server.py +46 -0
  181. mantisdk/verl/config.yaml +27 -0
  182. mantisdk/verl/daemon.py +1154 -0
  183. mantisdk/verl/dataset.py +44 -0
  184. mantisdk/verl/entrypoint.py +248 -0
  185. mantisdk/verl/trainer.py +549 -0
  186. mantisdk-0.1.0.dist-info/METADATA +119 -0
  187. mantisdk-0.1.0.dist-info/RECORD +190 -0
  188. mantisdk-0.1.0.dist-info/WHEEL +4 -0
  189. mantisdk-0.1.0.dist-info/entry_points.txt +2 -0
  190. mantisdk-0.1.0.dist-info/licenses/LICENSE +19 -0
mantisdk/algorithm/gepa/lib/utils/stop_condition.py
@@ -0,0 +1,196 @@
+ """
+ Utility functions for graceful stopping of GEPA runs.
+ """
+
+ import os
+ import signal
+ import time
+ from typing import Literal, Protocol, runtime_checkable
+
+ from mantisdk.algorithm.gepa.lib.core.state import GEPAState
+
+
+ @runtime_checkable
+ class StopperProtocol(Protocol):
+     """
+     Protocol for stop condition objects.
+
+     A stopper is a callable object that returns True when the optimization should stop.
+     """
+
+     def __call__(self, gepa_state: GEPAState) -> bool:
+         """
+         Check if the optimization should stop.
+
+         Args:
+             gepa_state: The current GEPA state containing optimization information
+
+         Returns:
+             True if the optimization should stop, False otherwise.
+         """
+         ...
+
+
+ class TimeoutStopCondition(StopperProtocol):
+     """Stop callback that stops after a specified timeout."""
+
+     def __init__(self, timeout_seconds: float):
+         self.timeout_seconds = timeout_seconds
+         self.start_time = time.time()
+
+     def __call__(self, gepa_state: GEPAState) -> bool:
+         # return true if timeout has been reached
+         return time.time() - self.start_time > self.timeout_seconds
+
+
+ class FileStopper(StopperProtocol):
+     """
+     Stop callback that stops when a specific file exists.
+     """
+
+     def __init__(self, stop_file_path: str):
+         self.stop_file_path = stop_file_path
+
+     def __call__(self, gepa_state: GEPAState) -> bool:
+         # returns true if stop file exists
+         return os.path.exists(self.stop_file_path)
+
+     def remove_stop_file(self):
+         # remove the stop file
+         if os.path.exists(self.stop_file_path):
+             os.remove(self.stop_file_path)
+
+
+ class ScoreThresholdStopper(StopperProtocol):
+     """
+     Stop callback that stops when a score threshold is reached.
+     """
+
+     def __init__(self, threshold: float):
+         self.threshold = threshold
+
+     def __call__(self, gepa_state: GEPAState) -> bool:
+         # return true if score threshold is reached
+         try:
+             current_best_score = (
+                 max(gepa_state.program_full_scores_val_set) if gepa_state.program_full_scores_val_set else 0.0
+             )
+             return current_best_score >= self.threshold
+         except Exception:
+             return False
+
+
+ class NoImprovementStopper(StopperProtocol):
+     """
+     Stop callback that stops after a specified number of iterations without improvement.
+     """
+
+     def __init__(self, max_iterations_without_improvement: int):
+         self.max_iterations_without_improvement = max_iterations_without_improvement
+         self.best_score = float("-inf")
+         self.iterations_without_improvement = 0
+
+     def __call__(self, gepa_state: GEPAState) -> bool:
+         # return true if max iterations without improvement reached
+         try:
+             current_score = (
+                 max(gepa_state.program_full_scores_val_set) if gepa_state.program_full_scores_val_set else 0.0
+             )
+             if current_score > self.best_score:
+                 self.best_score = current_score
+                 self.iterations_without_improvement = 0
+             else:
+                 self.iterations_without_improvement += 1
+
+             return self.iterations_without_improvement >= self.max_iterations_without_improvement
+         except Exception:
+             return False
+
+     def reset(self):
+         """Reset the counter (useful when manually improving the score)."""
+         self.iterations_without_improvement = 0
+
+
+ class SignalStopper(StopperProtocol):
+     """Stop callback that stops when a signal is received."""
+
+     def __init__(self, signals=None):
+         self.signals = signals or [signal.SIGINT, signal.SIGTERM]
+         self._stop_requested = False
+         self._original_handlers = {}
+         self._setup_signal_handlers()
+
+     def _setup_signal_handlers(self):
+         """Set up signal handlers for graceful shutdown."""
+
+         def signal_handler(signum, frame):
+             self._stop_requested = True
+
+         # Store original handlers and set new ones
+         for sig in self.signals:
+             try:
+                 self._original_handlers[sig] = signal.signal(sig, signal_handler)
+             except (OSError, ValueError):
+                 # Signal not available on this platform
+                 pass
+
+     def __call__(self, gepa_state: GEPAState) -> bool:
+         # return true if a signal was received
+         return self._stop_requested
+
+     def cleanup(self):
+         """Restore original signal handlers."""
+         for sig, handler in self._original_handlers.items():
+             try:
+                 signal.signal(sig, handler)
+             except (OSError, ValueError):
+                 pass
+
+
+ class MaxTrackedCandidatesStopper(StopperProtocol):
+     """
+     Stop callback that stops after a maximum number of tracked candidates.
+     """
+
+     def __init__(self, max_tracked_candidates: int):
+         self.max_tracked_candidates = max_tracked_candidates
+
+     def __call__(self, gepa_state: GEPAState) -> bool:
+         # return true if max tracked candidates reached
+         return len(gepa_state.program_candidates) >= self.max_tracked_candidates
+
+
+ class MaxMetricCallsStopper(StopperProtocol):
+     """
+     Stop callback that stops after a maximum number of metric calls.
+     """
+
+     def __init__(self, max_metric_calls: int):
+         self.max_metric_calls = max_metric_calls
+
+     def __call__(self, gepa_state: GEPAState) -> bool:
+         # return true if max metric calls reached
+         return gepa_state.total_num_evals >= self.max_metric_calls
+
+
+ class CompositeStopper(StopperProtocol):
+     """
+     Stop callback that combines multiple stopping conditions.
+
+     Allows combining several stoppers and stopping when any or all of them are triggered.
+     """
+
+     def __init__(self, *stoppers: StopperProtocol, mode: Literal["any", "all"] = "any"):
+         self.stoppers = stoppers
+         self.mode = mode
+
+     def __call__(self, gepa_state: GEPAState) -> bool:
+         # return true if stopping condition is met
+         if self.mode == "any":
+             return any(stopper(gepa_state) for stopper in self.stoppers)
+         elif self.mode == "all":
+             return all(stopper(gepa_state) for stopper in self.stoppers)
+         else:
+             raise ValueError(f"Unknown mode: {self.mode}")
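All stoppers in this file share the `StopperProtocol` call signature, so they compose directly through `CompositeStopper`. A minimal sketch of combining them is below; the `FakeState` stand-in and the budget values are illustrative assumptions, not part of this module, and only mimic the `GEPAState` attributes referenced above.

```python
# Illustrative only: composes the stoppers defined in stop_condition.py.
from dataclasses import dataclass, field
from typing import Any, List

from mantisdk.algorithm.gepa.lib.utils.stop_condition import (
    CompositeStopper,
    MaxMetricCallsStopper,
    ScoreThresholdStopper,
    TimeoutStopCondition,
)


@dataclass
class FakeState:
    # Stand-in for GEPAState with just the attributes the stoppers read.
    program_full_scores_val_set: List[float] = field(default_factory=list)
    program_candidates: List[Any] = field(default_factory=list)
    total_num_evals: int = 0


stop = CompositeStopper(
    TimeoutStopCondition(timeout_seconds=3600),    # wall-clock budget
    ScoreThresholdStopper(threshold=0.95),         # good-enough score
    MaxMetricCallsStopper(max_metric_calls=2000),  # evaluation budget
    mode="any",                                    # stop when any condition fires
)

state = FakeState(program_full_scores_val_set=[0.4, 0.7], total_num_evals=120)
print(stop(state))  # False: no budget exhausted, threshold not reached
```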
@@ -0,0 +1,105 @@
1
+ # Copyright (c) Microsoft. All rights reserved.
2
+
3
+ """GEPA-specific tracing context for detailed execution tracking.
4
+
5
+ This module provides a context class that tracks GEPA's execution state
6
+ (generation, phase, candidate, batch) to enable rich tagging of traces.
7
+ """
8
+
9
+ import uuid
10
+ from dataclasses import dataclass, field
11
+ from typing import List, Optional, Set
12
+
13
+
14
+ @dataclass
15
+ class GEPATracingContext:
16
+ """Tracks execution state for detailed tracing in GEPA optimization.
17
+
18
+ This class maintains state about the current phase of GEPA execution,
19
+ generation/iteration number, and batch counts to enable rich tagging
20
+ of traces for filtering and analysis in Mantis.
21
+
22
+ GEPA-specific phases:
23
+ - "train-eval": Evaluating candidates on training data
24
+ - "validation-eval": Evaluating candidates on validation data
25
+ - "reflection": LLM reflection to improve prompts (distinct from validation!)
26
+
27
+ Example:
28
+ >>> ctx = GEPATracingContext()
29
+ >>> ctx.generation
30
+ 0
31
+ >>> ctx.session_id # Auto-generated for grouping traces
32
+ 'gepa-abc123def456'
33
+ >>> ctx.next_generation()
34
+ >>> ctx.generation
35
+ 1
36
+ >>> batch_id = ctx.next_batch()
37
+ >>> batch_id
38
+ 'batch-1'
39
+
40
+ Attributes:
41
+ generation: Current generation/iteration number (0-indexed).
42
+ phase: Current execution phase.
43
+ candidate_id: Short hash of the current candidate being evaluated.
44
+ batch_count: Number of batches processed in current generation.
45
+ training_item_ids: Set of item IDs seen during training (for validation detection).
46
+ session_id: Unique session identifier for grouping all traces in this GEPA run.
47
+ """
48
+
49
+ generation: int = 0
50
+ phase: str = "train-eval"
51
+ candidate_id: Optional[str] = None
52
+ batch_count: int = 0
53
+ training_item_ids: Set[str] = field(default_factory=set)
54
+ session_id: str = field(default_factory=lambda: f"gepa-{uuid.uuid4().hex[:12]}")
55
+
56
+ def next_batch(self) -> str:
57
+ """Increment batch count and return batch identifier.
58
+
59
+ Returns:
60
+ Batch identifier string (e.g., "batch-1").
61
+ """
62
+ self.batch_count += 1
63
+ return f"batch-{self.batch_count}"
64
+
65
+ def set_phase(self, phase: str) -> None:
66
+ """Set the current execution phase.
67
+
68
+ Args:
69
+ phase: Phase name (e.g., "train-eval", "validation-eval", "reflection").
70
+ """
71
+ self.phase = phase
72
+
73
+ def next_generation(self) -> None:
74
+ """Increment generation counter and reset batch count."""
75
+ self.generation += 1
76
+ self.batch_count = 0
77
+
78
+ def set_candidate(self, candidate_id: str) -> None:
79
+ """Set the current candidate identifier.
80
+
81
+ Args:
82
+ candidate_id: Short hash or identifier for the candidate.
83
+ """
84
+ self.candidate_id = candidate_id
85
+
86
+ def register_training_items(self, item_ids: List[str]) -> None:
87
+ """Register item IDs as training data for validation detection.
88
+
89
+ Args:
90
+ item_ids: List of item IDs from the training batch.
91
+ """
92
+ self.training_item_ids.update(item_ids)
93
+
94
+ def is_validation_batch(self, item_ids: List[str]) -> bool:
95
+ """Check if a batch contains validation items (not in training set).
96
+
97
+ Args:
98
+ item_ids: List of item IDs from the batch.
99
+
100
+ Returns:
101
+ True if any item is not in the training set.
102
+ """
103
+ if not self.training_item_ids:
104
+ return False
105
+ return any(item_id not in self.training_item_ids for item_id in item_ids)
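A brief sketch of how the context's methods fit together over one optimization step; the tag dictionary and the "abc123"/"task-*" identifiers are illustrative assumptions, and how GEPA's adapter actually attaches these tags to spans is not shown here.

```python
# Illustrative walkthrough of GEPATracingContext over one GEPA iteration.
from mantisdk.algorithm.gepa.tracing import GEPATracingContext

ctx = GEPATracingContext()

# Training mini-batch for a candidate prompt.
ctx.set_candidate("abc123")
ctx.set_phase("train-eval")
ctx.register_training_items(["task-1", "task-2", "task-3"])
batch_id = ctx.next_batch()  # "batch-1"

# A batch containing unseen items is treated as validation.
if ctx.is_validation_batch(["task-9"]):
    ctx.set_phase("validation-eval")

# Tags a tracer could attach to all spans emitted during this step.
tags = {
    "session": ctx.session_id,   # e.g. "gepa-<12 hex chars>"
    "generation": ctx.generation,
    "phase": ctx.phase,
    "candidate": ctx.candidate_id,
    "batch": batch_id,
}

ctx.next_generation()  # increments generation, resets batch_count
```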
mantisdk/algorithm/utils.py
@@ -0,0 +1,177 @@
+ # Copyright (c) Microsoft. All rights reserved.
+
+ from __future__ import annotations
+
+ import functools
+ import logging
+ import random
+ from collections.abc import Coroutine
+ from typing import (
+     TYPE_CHECKING,
+     Any,
+     Callable,
+     Concatenate,
+     Iterator,
+     List,
+     Literal,
+     Optional,
+     ParamSpec,
+     Sequence,
+     TypeVar,
+     overload,
+ )
+
+ from mantisdk.types import Dataset
+
+ if TYPE_CHECKING:
+     from mantisdk.llm_proxy import LLMProxy
+     from mantisdk.store.base import LightningStore
+
+     from .base import Algorithm
+
+ T_task = TypeVar("T_task")
+ T_algo = TypeVar("T_algo", bound="Algorithm")
+
+ P = ParamSpec("P")
+ R = TypeVar("R")
+
+
+ logger = logging.getLogger(__name__)
+
+
+ def batch_iter_over_dataset(dataset: Dataset[T_task], batch_size: int) -> Iterator[Sequence[T_task]]:
+     """
+     Create an infinite iterator that yields batches from the dataset.
+
+     When batch_size >= dataset size, yields the entire shuffled dataset repeatedly.
+     When batch_size < dataset size, yields batches of the specified size, reshuffling
+     after each complete pass through the dataset.
+
+     Args:
+         dataset: The dataset to iterate over.
+         batch_size: The desired batch size.
+
+     Yields:
+         Sequences of tasks from the dataset. Each task appears at most once per epoch.
+     """
+     if batch_size >= len(dataset):
+         while True:
+             dataset_copy = [dataset[i] for i in range(len(dataset))]
+             random.shuffle(dataset_copy)
+             yield dataset_copy
+
+     else:
+         current_batch: List[int] = []
+         while True:
+             indices = list(range(len(dataset)))
+             random.shuffle(indices)
+             for index in indices:
+                 if index in current_batch:
+                     continue
+                 current_batch.append(index)
+                 if len(current_batch) == batch_size:
+                     yield [dataset[index] for index in current_batch]
+                     current_batch = []
+
+
+ def with_store(
+     func: Callable[Concatenate[T_algo, LightningStore, P], Coroutine[Any, Any, R]],
+ ) -> Callable[Concatenate[T_algo, P], Coroutine[Any, Any, R]]:
+     """Inject the algorithm's `LightningStore` into coroutine methods.
+
+     The decorator calls `Algorithm.get_store()` once per invocation and passes the
+     resulting store as an explicit argument to the wrapped coroutine. Decorated
+     methods therefore receive the resolved store even when invoked by helper
+     utilities rather than directly by the algorithm.
+
+     Args:
+         func: The coroutine that expects `(self, store, *args, **kwargs)`.
+
+     Returns:
+         A coroutine wrapper that automatically retrieves the store and forwards it
+         to `func`.
+     """
+
+     @functools.wraps(func)
+     async def wrapper(self: T_algo, *args: P.args, **kwargs: P.kwargs) -> R:
+         store = self.get_store()
+         return await func(self, store, *args, **kwargs)
+
+     return wrapper
+
+
+ @overload
+ def with_llm_proxy(
+     required: Literal[False] = False,
+     auto_start: bool = True,
+ ) -> Callable[
+     [Callable[Concatenate[T_algo, Optional[LLMProxy], P], Coroutine[Any, Any, R]]],
+     Callable[Concatenate[T_algo, P], Coroutine[Any, Any, R]],
+ ]: ...
+
+
+ @overload
+ def with_llm_proxy(
+     required: Literal[True],
+     auto_start: bool = True,
+ ) -> Callable[
+     [Callable[Concatenate[T_algo, LLMProxy, P], Coroutine[Any, Any, R]]],
+     Callable[Concatenate[T_algo, P], Coroutine[Any, Any, R]],
+ ]: ...
+
+
+ def with_llm_proxy(
+     required: bool = False,
+     auto_start: bool = True,
+ ) -> Callable[
+     [Callable[..., Coroutine[Any, Any, Any]]],
+     Callable[..., Coroutine[Any, Any, Any]],
+ ]:
+     """Resolve and optionally lifecycle-manage the configured LLM proxy.
+
+     Args:
+         required: When True, raises `ValueError` if the algorithm does not have an
+             [`LLMProxy`][mantisdk.LLMProxy] set. When False, the wrapped coroutine receives
+             `None` if no proxy is available.
+         auto_start: When True, [`LLMProxy.start()`][mantisdk.LLMProxy.start] is invoked if the proxy is not
+             already running before calling `func` and [`LLMProxy.stop()`][mantisdk.LLMProxy.stop] is
+             called afterwards.
+
+     Returns:
+         A decorator that injects the [`LLMProxy`][mantisdk.LLMProxy] (or `None`) as the first
+         argument after `self` and manages automatic startup/shutdown when requested.
+     """
+
+     def decorator(
+         func: Callable[..., Coroutine[Any, Any, Any]],
+     ) -> Callable[..., Coroutine[Any, Any, Any]]:
+         @functools.wraps(func)
+         async def wrapper(self: Algorithm, *args: Any, **kwargs: Any) -> Any:
+             llm_proxy = self.get_llm_proxy()
+
+             if required and llm_proxy is None:
+                 raise ValueError(
+                     "LLM proxy is required but not configured. Call set_llm_proxy() before using this method."
+                 )
+
+             auto_started = False
+             if auto_start and llm_proxy is not None:
+                 if llm_proxy.is_running():
+                     logger.info("Proxy is already running, skipping start")
+                 else:
+                     logger.info("Starting proxy, managed by the algorithm")
+                     await llm_proxy.start()
+                     auto_started = True
+
+             try:
+                 # At type level, overloads guarantee that if `required=True`
+                 # then `func` expects a non-optional LLMProxy.
+                 return await func(self, llm_proxy, *args, **kwargs)
+             finally:
+                 if auto_started and llm_proxy is not None:
+                     logger.info("Stopping proxy, managed by the algorithm")
+                     await llm_proxy.stop()
+
+         return wrapper
+
+     return decorator
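A sketch of how these helpers might be used together inside an `Algorithm` subclass. `MyAlgorithm`, its `optimize` method, and the loop body are hypothetical names introduced here for illustration; the decorator behaviour and import paths follow the code above.

```python
# Hypothetical Algorithm subclass showing the injected-argument pattern.
from typing import Any

from mantisdk.algorithm.base import Algorithm
from mantisdk.algorithm.utils import batch_iter_over_dataset, with_llm_proxy, with_store
from mantisdk.llm_proxy import LLMProxy
from mantisdk.store.base import LightningStore
from mantisdk.types import Dataset


class MyAlgorithm(Algorithm):
    # The inner @with_store wrapper injects `store` immediately after `self`;
    # the outer @with_llm_proxy wrapper injects `llm_proxy` right after it,
    # so callers only pass the dataset and batch size.
    @with_llm_proxy(required=True, auto_start=True)
    @with_store
    async def optimize(
        self,
        store: LightningStore,   # injected by @with_store
        llm_proxy: LLMProxy,     # injected by @with_llm_proxy(required=True)
        train_dataset: Dataset[Any],
        batch_size: int = 8,
    ) -> None:
        for batch in batch_iter_over_dataset(train_dataset, batch_size):
            # Enqueue rollouts via `store` and route model calls through `llm_proxy`.
            ...
            break  # batch_iter_over_dataset is infinite; stop after one batch here
```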
mantisdk/algorithm/verl/__init__.py
@@ -0,0 +1,5 @@
+ # Copyright (c) Microsoft. All rights reserved.
+
+ from .interface import VERL
+
+ __all__ = ["VERL"]
mantisdk/algorithm/verl/interface.py
@@ -0,0 +1,202 @@
+ # Copyright (c) Microsoft. All rights reserved.
+
+ from __future__ import annotations
+
+ from typing import TYPE_CHECKING, Any, Optional, Type
+
+ from hydra import compose, initialize
+ from omegaconf import OmegaConf
+
+ from mantisdk.algorithm.base import Algorithm
+ from mantisdk.client import MantisdkClient
+ from mantisdk.types import Dataset
+ from mantisdk.verl.entrypoint import run_ppo  # type: ignore
+
+ if TYPE_CHECKING:
+     from mantisdk.verl.daemon import AgentModeDaemon
+     from mantisdk.verl.trainer import MantisdkTrainer
+
+
+ class VERL(Algorithm):
+     """VERL-powered algorithm that delegates training to the VERL PPO runner.
+
+     !!! warning
+         Advanced customisation currently requires copying the VERL source and
+         modifying it directly. Native hooks for overriding training behaviour
+         will land in a future release.
+
+     Args:
+         config: Dictionary mirroring the overrides passed to the VERL CLI. The
+             overrides are merged with VERL's packaged defaults via Hydra before
+             launching training.
+         trainer_cls: Optional override for the trainer class. Experimental.
+         daemon_cls: Optional override for the daemon class. Experimental.
+
+     !!! note "Trajectory aggregation (experimental)"
+
+         Trajectory-level aggregation merges an entire multi-turn rollout into a single,
+         masked training sample so GPU time is spent once per trajectory rather than N times
+         per turn. Enable it via:
+
+         ```python
+         config["mantisdk"]["trace_aggregator"] = {
+             "level": "trajectory",
+             "trajectory_max_prompt_length": 4096,
+             "trajectory_max_response_length": 34384,
+         }
+         ```
+
+         Keep conversations structured (message lists rather than manual string
+         concatenation) so prefix matching can stitch traces. `trajectory_max_prompt_length`
+         should be set to the maximum length of the prompt for the first turn, and
+         `trajectory_max_response_length` should be set to the maximum cumulative
+         length of agent responses in the full trajectory.
+         Toggle `debug=True` plus `mismatch_log_dir` when you need to inspect
+         retokenization or chat-template mismatches. See
+         [this blog post](https://mantisdk.github.io/posts/trajectory_level_aggregation/)
+         for more details.
+
+     Examples:
+         ```python
+         from mantisdk.algorithm.verl import VERL
+
+         algorithm = VERL(
+             config={
+                 "algorithm": {
+                     "adv_estimator": "grpo",
+                     "use_kl_in_reward": False,
+                 },
+                 "data": {
+                     "train_batch_size": 32,
+                     "max_prompt_length": 4096,
+                     "max_response_length": 2048,
+                 },
+                 "actor_rollout_ref": {
+                     "rollout": {
+                         "tensor_model_parallel_size": 1,
+                         "n": 4,
+                         "log_prob_micro_batch_size_per_gpu": 4,
+                         "multi_turn": {"format": "hermes"},
+                         "name": "vllm",
+                         "gpu_memory_utilization": 0.6,
+                     },
+                     "actor": {
+                         "ppo_mini_batch_size": 32,
+                         "ppo_micro_batch_size_per_gpu": 4,
+                         "optim": {"lr": 1e-6},
+                         "use_kl_loss": False,
+                         "kl_loss_coef": 0.0,
+                         "entropy_coeff": 0,
+                         "clip_ratio_low": 0.2,
+                         "clip_ratio_high": 0.3,
+                         "fsdp_config": {
+                             "param_offload": True,
+                             "optimizer_offload": True,
+                         },
+                     },
+                     "ref": {
+                         "log_prob_micro_batch_size_per_gpu": 8,
+                         "fsdp_config": {"param_offload": True},
+                     },
+                     "model": {
+                         "path": "Qwen/Qwen2.5-1.5B-Instruct",
+                         "use_remove_padding": True,
+                         "enable_gradient_checkpointing": True,
+                     },
+                 },
+                 "trainer": {
+                     "n_gpus_per_node": 1,
+                     "val_before_train": True,
+                     "critic_warmup": 0,
+                     "logger": ["console", "wandb"],
+                     "project_name": "Mantisdk",
+                     "experiment_name": "calc_x",
+                     "nnodes": 1,
+                     "save_freq": 64,
+                     "test_freq": 32,
+                     "total_epochs": 2,
+                 },
+             }
+         )
+         trainer.fit(algorithm, train_dataset=my_train_dataset)
+         ```
+     """
+
+     def __init__(
+         self,
+         config: dict[str, Any],
+         trainer_cls: Optional[Type[MantisdkTrainer]] = None,
+         daemon_cls: Optional[Type[AgentModeDaemon]] = None,
+     ):
+         super().__init__()
+
+         # Compose the base config from the packaged VERL defaults.
+         with initialize(version_base=None, config_path="pkg://mantisdk/verl"):
+             base_cfg = compose(config_name="config")
+
+         # Merge the dict overrides.
+         override_conf = OmegaConf.create(config)
+         # Allow adding new fields
+         OmegaConf.set_struct(base_cfg, False)
+         self.config = OmegaConf.merge(base_cfg, override_conf)
+         self.trainer_cls = trainer_cls
+         self.daemon_cls = daemon_cls
+
+     def run(
+         self,
+         train_dataset: Optional[Dataset[Any]] = None,
+         val_dataset: Optional[Dataset[Any]] = None,
+     ) -> None:
+         """Launch the VERL PPO entrypoint with the configured runtime context.
+
+         Args:
+             train_dataset: Optional dataset forwarded to VERL for training.
+             val_dataset: Optional dataset forwarded to VERL for evaluation.
+
+         Raises:
+             ValueError: If required dependencies such as the store, LLM proxy, or
+                 adapter have been garbage-collected when using the V1 execution
+                 mode.
+         """
+         from mantisdk.verl.daemon import AgentModeDaemon
+         from mantisdk.verl.trainer import MantisdkTrainer
+
+         trainer_cls = self.trainer_cls or MantisdkTrainer
+         daemon_cls = self.daemon_cls or AgentModeDaemon
+         try:
+             store = self.get_store()
+         except Exception:
+             print("Store is not set. Assuming v0 execution mode.")
+             run_ppo(
+                 self.config,
+                 train_dataset=train_dataset,
+                 val_dataset=val_dataset,
+                 store=None,
+                 llm_proxy=None,
+                 adapter=None,
+                 trainer_cls=trainer_cls,
+                 daemon_cls=daemon_cls,
+             )
+         else:
+             print("Store is set. Assuming v1 execution mode.")
+             llm_proxy = self.get_llm_proxy()
+             adapter = self.get_adapter()
+             run_ppo(
+                 self.config,
+                 train_dataset=train_dataset,
+                 val_dataset=val_dataset,
+                 store=store,
+                 llm_proxy=llm_proxy,
+                 adapter=adapter,
+                 trainer_cls=trainer_cls,
+                 daemon_cls=daemon_cls,
+             )
+
+     def get_client(self) -> MantisdkClient:
+         """Create a client bound to the VERL-managed Mantisdk server.
+
+         Deprecated:
+             Since v0.2.
+         """
+         port = self.config.mantisdk.port
+         return MantisdkClient(endpoint=f"http://localhost:{port}")
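The constructor's override handling can be illustrated without Hydra: a standalone sketch of the `OmegaConf` merge step, using made-up config keys (`trainer.total_epochs`, `mantisdk.port`, etc. here are just placeholders for the real defaults the package ships).

```python
# Standalone illustration of the merge step in VERL.__init__ (OmegaConf only, no Hydra).
from omegaconf import OmegaConf

# Stand-in for the packaged defaults loaded via hydra.compose("config").
base_cfg = OmegaConf.create({"trainer": {"nnodes": 1, "total_epochs": 1}, "data": {"train_batch_size": 8}})
# Stand-in for the user-supplied `config` dict of overrides.
override = OmegaConf.create({"trainer": {"total_epochs": 2}, "mantisdk": {"port": 9999}})

OmegaConf.set_struct(base_cfg, False)  # permit keys the defaults don't declare (e.g. "mantisdk")
merged = OmegaConf.merge(base_cfg, override)

print(merged.trainer.total_epochs)   # 2    -> override wins
print(merged.data.train_batch_size)  # 8    -> default kept
print(merged.mantisdk.port)          # 9999 -> new section added by the override
```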