mantisdk-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of mantisdk might be problematic.

Files changed (190)
  1. mantisdk/__init__.py +22 -0
  2. mantisdk/adapter/__init__.py +15 -0
  3. mantisdk/adapter/base.py +94 -0
  4. mantisdk/adapter/messages.py +270 -0
  5. mantisdk/adapter/triplet.py +1028 -0
  6. mantisdk/algorithm/__init__.py +39 -0
  7. mantisdk/algorithm/apo/__init__.py +5 -0
  8. mantisdk/algorithm/apo/apo.py +889 -0
  9. mantisdk/algorithm/apo/prompts/apply_edit_variant01.poml +22 -0
  10. mantisdk/algorithm/apo/prompts/apply_edit_variant02.poml +18 -0
  11. mantisdk/algorithm/apo/prompts/text_gradient_variant01.poml +18 -0
  12. mantisdk/algorithm/apo/prompts/text_gradient_variant02.poml +16 -0
  13. mantisdk/algorithm/apo/prompts/text_gradient_variant03.poml +107 -0
  14. mantisdk/algorithm/base.py +162 -0
  15. mantisdk/algorithm/decorator.py +264 -0
  16. mantisdk/algorithm/fast.py +250 -0
  17. mantisdk/algorithm/gepa/__init__.py +59 -0
  18. mantisdk/algorithm/gepa/adapter.py +459 -0
  19. mantisdk/algorithm/gepa/gepa.py +364 -0
  20. mantisdk/algorithm/gepa/lib/__init__.py +18 -0
  21. mantisdk/algorithm/gepa/lib/adapters/README.md +12 -0
  22. mantisdk/algorithm/gepa/lib/adapters/__init__.py +0 -0
  23. mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/README.md +341 -0
  24. mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/__init__.py +1 -0
  25. mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/anymaths_adapter.py +174 -0
  26. mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/requirements.txt +1 -0
  27. mantisdk/algorithm/gepa/lib/adapters/default_adapter/README.md +0 -0
  28. mantisdk/algorithm/gepa/lib/adapters/default_adapter/__init__.py +0 -0
  29. mantisdk/algorithm/gepa/lib/adapters/default_adapter/default_adapter.py +209 -0
  30. mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/README.md +7 -0
  31. mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/__init__.py +0 -0
  32. mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/dspy_adapter.py +307 -0
  33. mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/README.md +99 -0
  34. mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/dspy_program_proposal_signature.py +137 -0
  35. mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/full_program_adapter.py +266 -0
  36. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/GEPA_RAG.md +621 -0
  37. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/__init__.py +56 -0
  38. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/evaluation_metrics.py +226 -0
  39. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/generic_rag_adapter.py +496 -0
  40. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/rag_pipeline.py +238 -0
  41. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_store_interface.py +212 -0
  42. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/__init__.py +2 -0
  43. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/chroma_store.py +196 -0
  44. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/lancedb_store.py +422 -0
  45. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/milvus_store.py +409 -0
  46. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/qdrant_store.py +368 -0
  47. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/weaviate_store.py +418 -0
  48. mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/README.md +552 -0
  49. mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/__init__.py +37 -0
  50. mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/mcp_adapter.py +705 -0
  51. mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/mcp_client.py +364 -0
  52. mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/README.md +9 -0
  53. mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/__init__.py +0 -0
  54. mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/terminal_bench_adapter.py +217 -0
  55. mantisdk/algorithm/gepa/lib/api.py +375 -0
  56. mantisdk/algorithm/gepa/lib/core/__init__.py +0 -0
  57. mantisdk/algorithm/gepa/lib/core/adapter.py +180 -0
  58. mantisdk/algorithm/gepa/lib/core/data_loader.py +74 -0
  59. mantisdk/algorithm/gepa/lib/core/engine.py +356 -0
  60. mantisdk/algorithm/gepa/lib/core/result.py +233 -0
  61. mantisdk/algorithm/gepa/lib/core/state.py +636 -0
  62. mantisdk/algorithm/gepa/lib/examples/__init__.py +0 -0
  63. mantisdk/algorithm/gepa/lib/examples/aime.py +24 -0
  64. mantisdk/algorithm/gepa/lib/examples/anymaths-bench/eval_default.py +111 -0
  65. mantisdk/algorithm/gepa/lib/examples/anymaths-bench/prompt-templates/instruction_prompt.txt +9 -0
  66. mantisdk/algorithm/gepa/lib/examples/anymaths-bench/prompt-templates/optimal_prompt.txt +24 -0
  67. mantisdk/algorithm/gepa/lib/examples/anymaths-bench/train_anymaths.py +177 -0
  68. mantisdk/algorithm/gepa/lib/examples/dspy_full_program_evolution/arc_agi.ipynb +25705 -0
  69. mantisdk/algorithm/gepa/lib/examples/dspy_full_program_evolution/example.ipynb +348 -0
  70. mantisdk/algorithm/gepa/lib/examples/mcp_adapter/__init__.py +4 -0
  71. mantisdk/algorithm/gepa/lib/examples/mcp_adapter/mcp_optimization_example.py +455 -0
  72. mantisdk/algorithm/gepa/lib/examples/rag_adapter/RAG_GUIDE.md +613 -0
  73. mantisdk/algorithm/gepa/lib/examples/rag_adapter/__init__.py +9 -0
  74. mantisdk/algorithm/gepa/lib/examples/rag_adapter/rag_optimization.py +824 -0
  75. mantisdk/algorithm/gepa/lib/examples/rag_adapter/requirements-rag.txt +29 -0
  76. mantisdk/algorithm/gepa/lib/examples/terminal-bench/prompt-templates/instruction_prompt.txt +16 -0
  77. mantisdk/algorithm/gepa/lib/examples/terminal-bench/prompt-templates/terminus.txt +9 -0
  78. mantisdk/algorithm/gepa/lib/examples/terminal-bench/train_terminus.py +161 -0
  79. mantisdk/algorithm/gepa/lib/gepa_utils.py +117 -0
  80. mantisdk/algorithm/gepa/lib/logging/__init__.py +0 -0
  81. mantisdk/algorithm/gepa/lib/logging/experiment_tracker.py +187 -0
  82. mantisdk/algorithm/gepa/lib/logging/logger.py +75 -0
  83. mantisdk/algorithm/gepa/lib/logging/utils.py +103 -0
  84. mantisdk/algorithm/gepa/lib/proposer/__init__.py +0 -0
  85. mantisdk/algorithm/gepa/lib/proposer/base.py +31 -0
  86. mantisdk/algorithm/gepa/lib/proposer/merge.py +357 -0
  87. mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/__init__.py +0 -0
  88. mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/base.py +49 -0
  89. mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/reflective_mutation.py +176 -0
  90. mantisdk/algorithm/gepa/lib/py.typed +0 -0
  91. mantisdk/algorithm/gepa/lib/strategies/__init__.py +0 -0
  92. mantisdk/algorithm/gepa/lib/strategies/batch_sampler.py +77 -0
  93. mantisdk/algorithm/gepa/lib/strategies/candidate_selector.py +50 -0
  94. mantisdk/algorithm/gepa/lib/strategies/component_selector.py +36 -0
  95. mantisdk/algorithm/gepa/lib/strategies/eval_policy.py +64 -0
  96. mantisdk/algorithm/gepa/lib/strategies/instruction_proposal.py +127 -0
  97. mantisdk/algorithm/gepa/lib/utils/__init__.py +10 -0
  98. mantisdk/algorithm/gepa/lib/utils/stop_condition.py +196 -0
  99. mantisdk/algorithm/gepa/tracing.py +105 -0
  100. mantisdk/algorithm/utils.py +177 -0
  101. mantisdk/algorithm/verl/__init__.py +5 -0
  102. mantisdk/algorithm/verl/interface.py +202 -0
  103. mantisdk/cli/__init__.py +56 -0
  104. mantisdk/cli/prometheus.py +115 -0
  105. mantisdk/cli/store.py +131 -0
  106. mantisdk/cli/vllm.py +29 -0
  107. mantisdk/client.py +408 -0
  108. mantisdk/config.py +348 -0
  109. mantisdk/emitter/__init__.py +43 -0
  110. mantisdk/emitter/annotation.py +370 -0
  111. mantisdk/emitter/exception.py +54 -0
  112. mantisdk/emitter/message.py +61 -0
  113. mantisdk/emitter/object.py +117 -0
  114. mantisdk/emitter/reward.py +320 -0
  115. mantisdk/env_var.py +156 -0
  116. mantisdk/execution/__init__.py +15 -0
  117. mantisdk/execution/base.py +64 -0
  118. mantisdk/execution/client_server.py +443 -0
  119. mantisdk/execution/events.py +69 -0
  120. mantisdk/execution/inter_process.py +16 -0
  121. mantisdk/execution/shared_memory.py +282 -0
  122. mantisdk/instrumentation/__init__.py +119 -0
  123. mantisdk/instrumentation/agentops.py +314 -0
  124. mantisdk/instrumentation/agentops_langchain.py +45 -0
  125. mantisdk/instrumentation/litellm.py +83 -0
  126. mantisdk/instrumentation/vllm.py +81 -0
  127. mantisdk/instrumentation/weave.py +500 -0
  128. mantisdk/litagent/__init__.py +11 -0
  129. mantisdk/litagent/decorator.py +536 -0
  130. mantisdk/litagent/litagent.py +252 -0
  131. mantisdk/llm_proxy.py +1890 -0
  132. mantisdk/logging.py +370 -0
  133. mantisdk/reward.py +7 -0
  134. mantisdk/runner/__init__.py +11 -0
  135. mantisdk/runner/agent.py +845 -0
  136. mantisdk/runner/base.py +182 -0
  137. mantisdk/runner/legacy.py +309 -0
  138. mantisdk/semconv.py +170 -0
  139. mantisdk/server.py +401 -0
  140. mantisdk/store/__init__.py +23 -0
  141. mantisdk/store/base.py +897 -0
  142. mantisdk/store/client_server.py +2092 -0
  143. mantisdk/store/collection/__init__.py +30 -0
  144. mantisdk/store/collection/base.py +587 -0
  145. mantisdk/store/collection/memory.py +970 -0
  146. mantisdk/store/collection/mongo.py +1412 -0
  147. mantisdk/store/collection_based.py +1823 -0
  148. mantisdk/store/insight.py +648 -0
  149. mantisdk/store/listener.py +58 -0
  150. mantisdk/store/memory.py +396 -0
  151. mantisdk/store/mongo.py +165 -0
  152. mantisdk/store/sqlite.py +3 -0
  153. mantisdk/store/threading.py +357 -0
  154. mantisdk/store/utils.py +142 -0
  155. mantisdk/tracer/__init__.py +16 -0
  156. mantisdk/tracer/agentops.py +242 -0
  157. mantisdk/tracer/base.py +287 -0
  158. mantisdk/tracer/dummy.py +106 -0
  159. mantisdk/tracer/otel.py +555 -0
  160. mantisdk/tracer/weave.py +677 -0
  161. mantisdk/trainer/__init__.py +6 -0
  162. mantisdk/trainer/init_utils.py +263 -0
  163. mantisdk/trainer/legacy.py +367 -0
  164. mantisdk/trainer/registry.py +12 -0
  165. mantisdk/trainer/trainer.py +618 -0
  166. mantisdk/types/__init__.py +6 -0
  167. mantisdk/types/core.py +553 -0
  168. mantisdk/types/resources.py +204 -0
  169. mantisdk/types/tracer.py +515 -0
  170. mantisdk/types/tracing.py +218 -0
  171. mantisdk/utils/__init__.py +1 -0
  172. mantisdk/utils/id.py +18 -0
  173. mantisdk/utils/metrics.py +1025 -0
  174. mantisdk/utils/otel.py +578 -0
  175. mantisdk/utils/otlp.py +536 -0
  176. mantisdk/utils/server_launcher.py +1045 -0
  177. mantisdk/utils/system_snapshot.py +81 -0
  178. mantisdk/verl/__init__.py +8 -0
  179. mantisdk/verl/__main__.py +6 -0
  180. mantisdk/verl/async_server.py +46 -0
  181. mantisdk/verl/config.yaml +27 -0
  182. mantisdk/verl/daemon.py +1154 -0
  183. mantisdk/verl/dataset.py +44 -0
  184. mantisdk/verl/entrypoint.py +248 -0
  185. mantisdk/verl/trainer.py +549 -0
  186. mantisdk-0.1.0.dist-info/METADATA +119 -0
  187. mantisdk-0.1.0.dist-info/RECORD +190 -0
  188. mantisdk-0.1.0.dist-info/WHEEL +4 -0
  189. mantisdk-0.1.0.dist-info/entry_points.txt +2 -0
  190. mantisdk-0.1.0.dist-info/licenses/LICENSE +19 -0
mantisdk/verl/trainer.py
@@ -0,0 +1,549 @@
+ # Copyright (c) Microsoft. All rights reserved.
+
+ # type: ignore
+
+ from __future__ import annotations
+
+ import random
+ from contextlib import contextmanager
+ from copy import deepcopy
+ from pprint import pprint
+ from typing import Any, Dict, Tuple, Type
+
+ import numpy as np
+ import torch
+ import verl
+ from codetiming import Timer
+ from omegaconf import OmegaConf
+ from tqdm import tqdm
+ from verl import DataProto
+ from verl.protocol import pad_dataproto_to_divisor, unpad_dataproto
+ from verl.trainer.ppo.core_algos import agg_loss
+ from verl.trainer.ppo.metric_utils import (
+     _compute_response_info,
+     compute_throughout_metrics,
+     compute_timing_metrics,
+ )
+ from verl.trainer.ppo.ray_trainer import (
+     AdvantageEstimator,
+     RayPPOTrainer,
+     apply_kl_penalty,
+     compute_advantage,
+     compute_response_mask,
+ )
+ from verl.utils.metric import reduce_metrics
+ from verl.utils.tracking import Tracking
+
+ from mantisdk.adapter import TraceAdapter, TraceToTripletBase
+ from mantisdk.llm_proxy import LLMProxy
+ from mantisdk.store.base import LightningStore
+
+ from .daemon import AgentModeDaemon
+
+ __all__ = [
+     "MantisdkTrainer",
+ ]
+
+
+ @contextmanager
+ def _timer(name: str, timing_raw: Dict[str, float]):
+     with Timer(name=name, logger=None) as timer:
+         yield
+     if name not in timing_raw:
+         timing_raw[name] = 0
+     timing_raw[name] += timer.last
+
+
+ # This function is adapted from verl.
+ # We introduce a new parameter `suffix` to distinguish between metrics computed
+ # before and after Mantisdk's post-processing.
+ # - "Before" refers to raw reward and advantage values.
+ # - "After" refers to values computed following post-processing, which involves:
+ #   (1) Dropping prompts that exceed the maximum allowed length.
+ #   (2) Adjusting the batch size to be a multiple of the mini PPO size.
+ # Different suffixes are used to label these two stages accordingly.
+ def compute_data_metrics(batch: DataProto, use_critic: bool = True, suffix: str = "") -> Dict[str, Any]:
+     """
+     Computes various metrics from a batch of data for PPO training.
+
+     This function calculates metrics related to scores, rewards, advantages, returns, values,
+     and sequence lengths from a batch of data. It provides statistical information (mean, max, min)
+     for each metric category.
+
+     Args:
+         batch: A DataProto object containing batch data with token-level scores, rewards, advantages, etc.
+         use_critic: Whether to include critic-specific metrics. Defaults to True.
+
+     Returns:
+         A dictionary of metrics including:
+             - critic/score/mean, max, min: Statistics about sequence scores
+             - critic/rewards/mean, max, min: Statistics about sequence rewards
+             - critic/advantages/mean, max, min: Statistics about advantages
+             - critic/returns/mean, max, min: Statistics about returns
+             - critic/values/mean, max, min: Statistics about critic values (if use_critic=True)
+             - critic/vf_explained_var: Explained variance of the value function (if use_critic=True)
+             - response_length/mean, max, min, clip_ratio: Statistics about response lengths
+             - prompt_length/mean, max, min, clip_ratio: Statistics about prompt lengths
+     """
+     sequence_score = batch.batch["token_level_scores"].sum(-1)
+     sequence_reward = batch.batch["token_level_rewards"].sum(-1)
+
+     advantages = batch.batch["advantages"]
+     returns = batch.batch["returns"]
+
+     max_response_length = batch.batch["responses"].shape[-1]
+
+     prompt_mask = batch.batch["attention_mask"][:, :-max_response_length].bool()
+     response_mask = batch.batch["attention_mask"][:, -max_response_length:].bool()
+
+     max_prompt_length = prompt_mask.size(-1)
+
+     response_info = _compute_response_info(batch)
+     prompt_length = response_info["prompt_length"]
+     response_length = response_info["response_length"]
+
+     valid_adv = torch.masked_select(advantages, response_mask)
+     valid_returns = torch.masked_select(returns, response_mask)
+
+     if use_critic:
+         values = batch.batch["values"]
+         valid_values = torch.masked_select(values, response_mask)
+         return_diff_var = torch.var(valid_returns - valid_values)
+         return_var = torch.var(valid_returns)
+
+     metrics = {
+         # score
+         "critic/score/mean" + suffix: torch.mean(sequence_score).detach().item(),
+         "critic/score/max" + suffix: torch.max(sequence_score).detach().item(),
+         "critic/score/min" + suffix: torch.min(sequence_score).detach().item(),
+         # reward
+         "critic/rewards/mean" + suffix: torch.mean(sequence_reward).detach().item(),
+         "critic/rewards/max" + suffix: torch.max(sequence_reward).detach().item(),
+         "critic/rewards/min" + suffix: torch.min(sequence_reward).detach().item(),
+         # adv
+         "critic/advantages/mean" + suffix: torch.mean(valid_adv).detach().item(),
+         "critic/advantages/max" + suffix: torch.max(valid_adv).detach().item(),
+         "critic/advantages/min" + suffix: torch.min(valid_adv).detach().item(),
+         # returns
+         "critic/returns/mean" + suffix: torch.mean(valid_returns).detach().item(),
+         "critic/returns/max" + suffix: torch.max(valid_returns).detach().item(),
+         "critic/returns/min" + suffix: torch.min(valid_returns).detach().item(),
+         **(
+             {
+                 # values
+                 "critic/values/mean" + suffix: torch.mean(valid_values).detach().item(),
+                 "critic/values/max" + suffix: torch.max(valid_values).detach().item(),
+                 "critic/values/min" + suffix: torch.min(valid_values).detach().item(),
+                 # vf explained var
+                 "critic/vf_explained_var" + suffix: (1.0 - return_diff_var / (return_var + 1e-5)).detach().item(),
+             }
+             if use_critic
+             else {}
+         ),
+         # response length
+         "response_length/mean" + suffix: torch.mean(response_length).detach().item(),
+         "response_length/max" + suffix: torch.max(response_length).detach().item(),
+         "response_length/min" + suffix: torch.min(response_length).detach().item(),
+         "response_length/clip_ratio"
+         + suffix: torch.mean(torch.eq(response_length, max_response_length).float()).detach().item(),
+         # prompt length
+         "prompt_length/mean" + suffix: torch.mean(prompt_length).detach().item(),
+         "prompt_length/max" + suffix: torch.max(prompt_length).detach().item(),
+         "prompt_length/min" + suffix: torch.min(prompt_length).detach().item(),
+         "prompt_length/clip_ratio"
+         + suffix: torch.mean(torch.eq(prompt_length, max_prompt_length).float()).detach().item(),
+     }
+     return metrics
+
+
+ class MantisdkTrainer(RayPPOTrainer):
+     """
+     Specialized PPO trainer for agent-based reinforcement learning.
+
+     This trainer is designed specifically for scenarios where the model interacts with
+     external environments, tools, or APIs through a MantisdkServer. It simplifies
+     the training loop by removing the complex conditional logic present in the original
+     RayPPOTrainer and focusing on the agent mode workflow.
+
+     Key differences from RayPPOTrainer:
+
+     1. Uses AgentModeDaemon for server communication
+     2. Simplified data flow without pop/union operations
+     3. Direct batch processing through agent daemon
+     4. Streamlined validation using agent_mode validation
+     """
+
+     def __init__(
+         self,
+         store: LightningStore | None,
+         llm_proxy: LLMProxy | None,
+         adapter: TraceAdapter | None,
+         daemon_cls: Type[AgentModeDaemon],
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+         self.store = store
+         self.llm_proxy = llm_proxy
+         self.adapter = adapter
+         self.daemon_cls = daemon_cls
+
+     def _validate(self):
+         assert len(self.val_dataloader) == 1, "Please set val_batch_size to None for better throughput."
+
+         test_data = next(iter(self.val_dataloader))
+         test_batch = DataProto.from_single_dict(test_data)
+
+         self.async_rollout_manager.wake_up()
+         self.agent_mode_daemon.set_up_data_and_server(
+             test_batch.non_tensor_batch,
+             self.async_rollout_manager.server_addresses,
+             is_train=False,
+         )
+         self.agent_mode_daemon.run_until_all_finished()
+         test_metrics = self.agent_mode_daemon.get_test_metrics()
+         self.agent_mode_daemon.clear_data_and_server()
+         self.async_rollout_manager.sleep()
+         return test_metrics
+
+     def _compute_reference_log_prob(self, batch: DataProto) -> DataProto:
+         """Compute reference log probability using the correct worker based on LoRA configuration.
+
+         In verl 0.6.0+, when LoRA is detected (indicated by ref_in_actor=True),
+         the reference policy is computed by the actor rollout worker instead of a separate
+         ref policy worker. This method handles both scenarios by checking the ref_in_actor flag.
+         Note: verl sets ref_in_actor=True when it detects LoRA configuration (e.g., lora_rank > 0 or lora_adapter_path is set).
+
+         Args:
+             batch: The data batch to compute reference log probabilities for.
+
+         Returns:
+             DataProto with reference log probabilities added.
+
+         Raises:
+             RuntimeError: If the required worker is not available.
+         """
+         if getattr(self, "ref_in_actor", False):
+             actor_worker = getattr(self, "actor_rollout_wg", None)
+             if actor_worker is None:
+                 raise RuntimeError("actor_rollout_wg is required when ref_in_actor is True.")
+             return actor_worker.compute_ref_log_prob(batch)
+
+         ref_worker = getattr(self, "ref_policy_wg", None)
+         if ref_worker is None:
+             raise RuntimeError(
+                 "Reference policy worker was not initialized. "
+                 "Ensure `use_reference_policy` is enabled and the VERL config exposes the ref worker."
+             )
+         return ref_worker.compute_ref_log_prob(batch)
+
+     def _train_step(self, batch_dict: dict) -> dict:
+         # Isolate in a separate method to automatically recycle the variables before validation.
+         batch: DataProto = DataProto.from_single_dict(batch_dict)
+         metrics = {}
+         timing_raw = {}
+
+         with _timer("step", timing_raw):
+
+             # When agent mode is enabled, we read the batch as it is.
+             gen_batch = batch
+
+             # generate a batch
+             with _timer("gen", timing_raw):
+                 self.async_rollout_manager.wake_up()
+                 self.agent_mode_daemon.set_up_data_and_server(
+                     gen_batch.non_tensor_batch, self.async_rollout_manager.server_addresses
+                 )
+                 self.agent_mode_daemon.run_until_all_finished()
+                 batch, agent_metrics = self.agent_mode_daemon.get_train_data_batch(
+                     max_prompt_length=(
+                         self.config.mantisdk.trace_aggregator.trajectory_max_prompt_length
+                         if self.config.mantisdk.trace_aggregator.level.startswith("trajectory")
+                         else self.config.data.max_prompt_length
+                     ),
+                     max_response_length=(
+                         self.config.mantisdk.trace_aggregator.trajectory_max_response_length
+                         if self.config.mantisdk.trace_aggregator.level.startswith("trajectory")
+                         else self.config.data.max_response_length
+                     ),
+                     device=gen_batch.batch["fake_ids"].device,
+                     global_steps=self.global_steps,
+                 )
+                 metrics.update(agent_metrics)
+                 self.agent_mode_daemon.clear_data_and_server()
+                 self.async_rollout_manager.sleep()
+
+             if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX:
+                 with _timer("gen_max", timing_raw):
+                     gen_baseline_batch = deepcopy(gen_batch)
+                     gen_baseline_batch.meta_info["do_sample"] = False
+                     gen_baseline_output = self.async_rollout_manager.generate_sequences(gen_baseline_batch)
+
+                     batch = batch.union(gen_baseline_output)
+                     reward_baseline_tensor = self.reward_fn(batch)
+                     reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1)
+
+                     batch.pop(batch_keys=list(gen_baseline_output.batch.keys()))
+
+                     batch.batch["reward_baselines"] = reward_baseline_tensor
+
+                     del gen_baseline_batch, gen_baseline_output
+
+             # uid is used for algorithm like GRPO, should be aligned to data id
+             batch.non_tensor_batch["uid"] = batch.non_tensor_batch["data_id_list"]
+
+             if "response_mask" not in batch.batch:
+                 batch.batch["response_mask"] = compute_response_mask(batch)
+
+             # compute global_valid tokens
+             batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist()
+
+             with _timer("reward", timing_raw):
+                 # compute reward model score
+                 if self.use_rm:
+                     reward_tensor = self.rm_wg.compute_rm_score(batch)
+                     batch = batch.union(reward_tensor)
+
+                 reward_extra_infos_dict = {}
+
+             # for agent mode, pad the lengths to calculate old log prob, ref, and values
+             batch, pad_size = pad_dataproto_to_divisor(batch, self.actor_rollout_wg.world_size)
+
+             # recompute old_log_probs
+             with _timer("old_log_prob", timing_raw):
+                 old_log_prob = self.actor_rollout_wg.compute_log_prob(batch)
+                 entropys = old_log_prob.batch["entropys"]
+                 response_masks = batch.batch["response_mask"]
+                 loss_agg_mode = self.config.actor_rollout_ref.actor.loss_agg_mode
+                 entropy_loss = agg_loss(loss_mat=entropys, loss_mask=response_masks, loss_agg_mode=loss_agg_mode)
+                 old_log_prob_metrics = {"actor/entropy_loss": entropy_loss.detach().item()}
+                 metrics.update(old_log_prob_metrics)
+                 old_log_prob.batch.pop("entropys")
+                 batch = batch.union(old_log_prob)
+
+             if self.use_reference_policy:
+                 # compute reference log_prob
+                 with _timer("ref", timing_raw):
+                     ref_log_prob = self._compute_reference_log_prob(batch)
+                     batch = batch.union(ref_log_prob)
+
+             # compute values
+             if self.use_critic:
+                 with _timer("values", timing_raw):
+                     values = self.critic_wg.compute_values(batch)
+                     batch = batch.union(values)
+
+             # for agent mode, unpad to calculate adv
+             # it is important, as adv should be based on the raw traces
+             batch = unpad_dataproto(batch, pad_size=pad_size)
+
+             with _timer("adv", timing_raw):
+                 # if agent_mode is enabled, there is already token_level_scores
+                 # token_level_scores is not needed to compute here
+
+                 # compute rewards. apply_kl_penalty if available
+                 if self.config.algorithm.use_kl_in_reward:
+                     batch, kl_metrics = apply_kl_penalty(
+                         batch, kl_ctrl=self.kl_ctrl_in_reward, kl_penalty=self.config.algorithm.kl_penalty
+                     )
+                     metrics.update(kl_metrics)
+                 else:
+                     batch.batch["token_level_rewards"] = batch.batch["token_level_scores"]
+
+                 # compute advantages, executed on the driver process
+
+                 norm_adv_by_std_in_grpo = self.config.algorithm.get(
+                     "norm_adv_by_std_in_grpo", True
+                 )  # GRPO adv normalization factor
+
+                 batch = compute_advantage(
+                     batch,
+                     adv_estimator=self.config.algorithm.adv_estimator,
+                     gamma=self.config.algorithm.gamma,
+                     lam=self.config.algorithm.lam,
+                     num_repeat=self.config.actor_rollout_ref.rollout.n,
+                     norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo,
+                     config=self.config.algorithm,
+                 )
+
+             # Calculate the metrics before processing. Refer to the comments of function `compute_data_metrics` for details.
+             metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic, suffix="_before_processing"))
+
+             # after advantages are assigned, we begin to drop (1) long prompt (2) floor to ppo minisize
+             keep_indices = (~batch.batch["is_drop_mask"]).nonzero(as_tuple=True)[0]
+             metrics["training/n_triplets_prompt_too_long"] = (
+                 batch.batch["is_drop_mask"].shape[0] - keep_indices.shape[0]
+             )
+             batch = batch[keep_indices]
+             # next, round to minibatch size
+             mini_batch_size = self.config.actor_rollout_ref.actor.ppo_mini_batch_size
+             n_transition = len(batch)
+             random_indices = list(range(n_transition))
+             random.shuffle(random_indices)
+             batch.reorder(torch.tensor(random_indices).type(torch.int32))
+             n_remained_transition = n_transition // mini_batch_size * mini_batch_size
+             batch = batch[list(range(n_remained_transition))]
+             metrics["training/n_triplets_dropped_remainder"] = n_transition - n_remained_transition
+
+             # Agent mode note: Change the order of balance batch;
+             # 1. first calculate advantage
+             # 2. then drop the samples (too long prompt & floor to ppo minisize)
+             # 3. balance
+             # balance the number of valid tokens on each dp rank.
+             # Note that this breaks the order of data inside the batch.
+             # Please take care when you implement group based adv computation such as GRPO and rloo
+             if self.config.trainer.balance_batch:
+                 self._balance_batch(batch, metrics=metrics)
+
+             # update critic
+             if self.use_critic:
+                 with _timer("update_critic", timing_raw):
+                     critic_output = self.critic_wg.update_critic(batch)
+                 critic_output_metrics = reduce_metrics(critic_output.meta_info["metrics"])
+                 metrics.update(critic_output_metrics)
+
+             # implement critic warmup
+             if self.config.trainer.critic_warmup <= self.global_steps:
+                 # update actor
+                 with _timer("update_actor", timing_raw):
+                     batch.meta_info["multi_turn"] = self.config.actor_rollout_ref.rollout.multi_turn.enable
+                     actor_output = self.actor_rollout_wg.update_actor(batch)
+                 actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"])
+                 metrics.update(actor_output_metrics)
+
+             # Log rollout generations if enabled
+             rollout_data_dir = self.config.trainer.get("rollout_data_dir", None)
+             if rollout_data_dir:
+                 with _timer("dump_rollout_generations", timing_raw):
+                     print(batch.batch.keys())
+                     inputs = self.tokenizer.batch_decode(batch.batch["prompts"], skip_special_tokens=True)
+                     outputs = self.tokenizer.batch_decode(batch.batch["responses"], skip_special_tokens=True)
+                     scores = batch.batch["token_level_scores"].sum(-1).cpu().tolist()
+                     self._dump_generations(
+                         inputs=inputs,
+                         outputs=outputs,
+                         scores=scores,
+                         reward_extra_infos_dict=reward_extra_infos_dict,
+                         dump_path=rollout_data_dir,
+                     )
+
+         # compute training metrics
+         metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic, suffix="_after_processing"))
+         metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw))
+         # TODO: implement actual tflpo and theoretical tflpo
+         n_gpus = self.resource_pool_manager.get_n_gpus()
+         metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus))
+
+         return metrics
+
+     def fit(self):
+         logger = Tracking(
+             project_name=self.config.trainer.project_name,
+             experiment_name=self.config.trainer.experiment_name,
+             default_backend=self.config.trainer.logger,
+             config=OmegaConf.to_container(self.config, resolve=True),
+         )
+
+         self.global_steps = 0
+
+         # load checkpoint before doing anything
+         self._load_checkpoint()
+
+         assert self.async_rollout_mode, "If agent mode is enabled, async server must be enabled"
+         if self.adapter is not None and not isinstance(self.adapter, TraceToTripletBase):
+             raise ValueError("Adapter must be a TraceToTripletBase for the current VERL implementation.")
+         verl_version = verl.__version__
+         if verl_version == "0.5.0":
+             # Note (Zhiyuan): To avoid further patch into vllm async server, using the same sentence to get the naming here.
+             # However, it is possible that verl updates the naming and causes incompatibility.
+             # Reference: https://github.com/volcengine/verl/blob/5b5e09d9cc20625e436d01f69d9cc739ff681c54/verl/workers/rollout/vllm_rollout/vllm_async_server.py#L217
+             model = "/".join(self.config.actor_rollout_ref.model.path.split("/")[-2:])
+         else:
+             # For other versions (e.g., 0.6.0), we use the full path to the model.
+             model = self.config.actor_rollout_ref.model.path
+         self.agent_mode_daemon = self.daemon_cls(
+             self.config.mantisdk.port,
+             self.config.actor_rollout_ref.rollout.n,
+             train_information={
+                 "model": model,
+                 "temperature": self.config.actor_rollout_ref.rollout.temperature,
+             },
+             tokenizer=self.tokenizer,
+             mini_batch_size=self.config.actor_rollout_ref.actor.ppo_mini_batch_size,
+             pad_token_id=self.tokenizer.pad_token_id,
+             mode="v1" if self.store is not None else "v0",
+             store=self.store,
+             llm_proxy=self.llm_proxy,
+             adapter=self.adapter,
+             processor=self.processor,  # For Qwen2-VL mrope position_ids
+             image_base_dir=getattr(self.config.data, "image_base_dir", None),
+             trace_aggregator=self.config.mantisdk.trace_aggregator,
+         )
+         self.agent_mode_daemon.start()
+
+         # perform validation before training
+         # currently, we only support validation using the reward_function.
+         if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True):
+             val_metrics = self._validate()
+             assert val_metrics, f"{val_metrics=}"
+             pprint(f"Initial validation metrics: {val_metrics}")
+             logger.log(data=val_metrics, step=self.global_steps)
+             if self.config.trainer.get("val_only", False):
+                 return
+
+         # add tqdm
+         progress_bar = tqdm(total=self.total_training_steps, initial=self.global_steps, desc="Training Progress")
+
+         # we start from step 1
+         self.global_steps += 1
+         last_val_metrics = None
+
+         for epoch in range(self.config.trainer.total_epochs):
+             for batch_dict in self.train_dataloader:
+                 metrics = {}
+                 timing_raw = {}
+                 is_last_step = self.global_steps >= self.total_training_steps
+
+                 # train step
+                 metrics = self._train_step(batch_dict)
+
+                 # validate
+                 if (
+                     self.val_reward_fn is not None
+                     and self.config.trainer.test_freq > 0
+                     and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0)
+                 ):
+                     with _timer("validate", timing_raw):
+                         val_metrics: dict = self._validate()
+                     if is_last_step:
+                         last_val_metrics = val_metrics
+                     metrics.update(val_metrics)
+
+                 if self.config.trainer.save_freq > 0 and (
+                     is_last_step or self.global_steps % self.config.trainer.save_freq == 0
+                 ):
+                     with _timer("save_checkpoint", timing_raw):
+                         self._save_checkpoint()
+
+                 # step metrics
+                 metrics.update(
+                     {
+                         "training/global_step": self.global_steps,
+                         "training/epoch": epoch,
+                     }
+                 )
+
+                 # TODO: make a canonical logger that supports various backend
+                 logger.log(data=metrics, step=self.global_steps)
+
+                 if is_last_step:
+                     pprint(f"Final validation metrics: {last_val_metrics}")
+                     progress_bar.close()
+
+                     # This exit logic is to ensure a robust CI.
+                     pprint("Flush the logger...")
+                     del logger  # Make sure the loggers are flushed and closed properly
+                     pprint(f"Training finished at step {self.global_steps}.")
+                     return
+
+                 progress_bar.update(1)
+                 self.global_steps += 1
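
The `_timer` helper near the top of this file is the only timing primitive the trainer uses: every `with _timer(name, timing_raw)` block adds its elapsed wall-clock time to `timing_raw[name]`, so phases that run several times per step accumulate into a single entry. A minimal standalone sketch of that accumulation pattern, reusing the helper exactly as defined above (requires only `codetiming`; the `time.sleep` is just a stand-in for real work):

```python
import time
from contextlib import contextmanager
from typing import Dict

from codetiming import Timer


@contextmanager
def _timer(name: str, timing_raw: Dict[str, float]):
    # Same pattern as in the trainer: accumulate per-phase wall time.
    with Timer(name=name, logger=None) as timer:
        yield
    if name not in timing_raw:
        timing_raw[name] = 0
    timing_raw[name] += timer.last


timing_raw: Dict[str, float] = {}
for _ in range(3):
    with _timer("gen", timing_raw):
        time.sleep(0.01)  # stand-in for rollout generation

print(timing_raw)  # e.g. {'gen': 0.03...}: three "gen" phases summed into one entry
```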
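
In `_train_step`, after advantages are assigned, the batch is shrunk in two stages: triplets whose prompts were too long (`is_drop_mask`) are removed, and the remainder is floored to a multiple of `ppo_mini_batch_size`; both counts are logged as `training/n_triplets_prompt_too_long` and `training/n_triplets_dropped_remainder`. A small worked example of that bookkeeping with made-up sizes (plain Python lists stand in for the real `DataProto`):

```python
import random

mini_batch_size = 8  # stands in for config.actor_rollout_ref.actor.ppo_mini_batch_size
is_drop_mask = [False] * 37 + [True] * 3  # 3 prompts exceeded the max length

# (1) drop triplets whose prompts were too long
keep_indices = [i for i, drop in enumerate(is_drop_mask) if not drop]
n_triplets_prompt_too_long = len(is_drop_mask) - len(keep_indices)  # 3

# (2) shuffle, then floor to a multiple of the PPO mini-batch size
n_transition = len(keep_indices)  # 37
random.shuffle(keep_indices)
n_remained_transition = n_transition // mini_batch_size * mini_batch_size  # 32
n_triplets_dropped_remainder = n_transition - n_remained_transition  # 5

print(n_triplets_prompt_too_long, n_remained_transition, n_triplets_dropped_remainder)
# 3 32 5
```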
mantisdk-0.1.0.dist-info/METADATA
@@ -0,0 +1,119 @@
+ Metadata-Version: 2.4
+ Name: mantisdk
+ Version: 0.1.0
+ Summary: Mantisdk - AI Agent Training and Evaluation Platform
+ Project-URL: Homepage, https://github.com/withmetis/mantis
+ Project-URL: Documentation, https://withmetis.github.io/mantis/mantisdk/
+ Project-URL: Repository, https://github.com/withmetis/mantis
+ Project-URL: Issues, https://github.com/withmetis/mantis/issues
+ Author-email: Metis Team <team@withmetis.ai>
+ License: MIT
+ License-File: LICENSE
+ Keywords: agents,ai,evaluation,llm,mantis,observability,reinforcement-learning,training
+ Classifier: Development Status :: 4 - Beta
+ Classifier: Intended Audience :: Developers
+ Classifier: License :: OSI Approved :: MIT License
+ Classifier: Programming Language :: Python :: 3
+ Classifier: Programming Language :: Python :: 3.10
+ Classifier: Programming Language :: Python :: 3.11
+ Classifier: Programming Language :: Python :: 3.12
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
+ Requires-Python: >=3.10
+ Requires-Dist: agentops>=0.4.13
+ Requires-Dist: aiohttp
+ Requires-Dist: aiologic
+ Requires-Dist: fastapi
+ Requires-Dist: flask
+ Requires-Dist: gepa>=0.0.24
+ Requires-Dist: gpustat
+ Requires-Dist: graphviz
+ Requires-Dist: gunicorn
+ Requires-Dist: litellm[proxy]>=1.74
+ Requires-Dist: openai
+ Requires-Dist: opentelemetry-api>=1.35
+ Requires-Dist: opentelemetry-exporter-otlp>=1.35
+ Requires-Dist: opentelemetry-sdk>=1.35
+ Requires-Dist: portpicker
+ Requires-Dist: psutil
+ Requires-Dist: pydantic>=2.11
+ Requires-Dist: rich
+ Requires-Dist: setproctitle
+ Requires-Dist: uvicorn
+ Requires-Dist: uvicorn-worker
+ Provides-Extra: apo
+ Requires-Dist: poml; extra == 'apo'
+ Provides-Extra: mongo
+ Requires-Dist: pymongo; extra == 'mongo'
+ Provides-Extra: verl
+ Requires-Dist: verl>=0.5.0; extra == 'verl'
+ Requires-Dist: vllm>=0.8.4; extra == 'verl'
+ Provides-Extra: weave
+ Requires-Dist: weave>=0.52.22; extra == 'weave'
+ Description-Content-Type: text/markdown
+
+ # Mantisdk
+
+ [![PyPI version](https://badge.fury.io/py/mantisdk.svg)](https://badge.fury.io/py/mantisdk)
+ [![License](https://img.shields.io/badge/license-MIT-blue.svg)](LICENSE)
+
+ **AI Agent Training and Evaluation Platform**
+
+ Mantisdk is a comprehensive toolkit for training and evaluating AI agents using reinforcement learning, automatic prompt optimization, and supervised fine-tuning.
+
+ ## Core Features
+
+ - Turn your agent into an optimizable beast with **minimal code changes**
+ - Build with **any** agent framework (LangChain, OpenAI Agent SDK, AutoGen, CrewAI, and more)
+ - **Selectively** optimize one or more agents in a multi-agent system
+ - Embraces **algorithms** like Reinforcement Learning, Automatic Prompt Optimization, Supervised Fine-tuning and more
+
+ ## Installation
+
+ ```bash
+ pip install mantisdk
+ ```
+
+ For optional dependencies:
+
+ ```bash
+ # For APO (Automatic Prompt Optimization)
+ pip install mantisdk[apo]
+
+ # For VERL integration
+ pip install mantisdk[verl]
+
+ # For Weave integration
+ pip install mantisdk[weave]
+
+ # For MongoDB store
+ pip install mantisdk[mongo]
+ ```
+
+ ## Quick Start
+
+ ```python
+ import mantisdk as msk
+
+ # Initialize the client
+ client = msk.MantisdkClient()
+
+ # Your agent code here...
+ ```
+
+ ## CLI Usage
+
+ ```bash
+ # Start the Mantisdk server
+ msk store serve
+
+ # Run with vLLM
+ msk vllm start
+ ```
+
+ ## Documentation
+
+ For full documentation, visit [https://withmetis.github.io/mantis/mantisdk/](https://withmetis.github.io/mantis/mantisdk/)
+
+ ## License
+
+ MIT License - see [LICENSE](LICENSE) for details.