mantisdk 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mantisdk might be problematic. Click here for more details.
- mantisdk/__init__.py +22 -0
- mantisdk/adapter/__init__.py +15 -0
- mantisdk/adapter/base.py +94 -0
- mantisdk/adapter/messages.py +270 -0
- mantisdk/adapter/triplet.py +1028 -0
- mantisdk/algorithm/__init__.py +39 -0
- mantisdk/algorithm/apo/__init__.py +5 -0
- mantisdk/algorithm/apo/apo.py +889 -0
- mantisdk/algorithm/apo/prompts/apply_edit_variant01.poml +22 -0
- mantisdk/algorithm/apo/prompts/apply_edit_variant02.poml +18 -0
- mantisdk/algorithm/apo/prompts/text_gradient_variant01.poml +18 -0
- mantisdk/algorithm/apo/prompts/text_gradient_variant02.poml +16 -0
- mantisdk/algorithm/apo/prompts/text_gradient_variant03.poml +107 -0
- mantisdk/algorithm/base.py +162 -0
- mantisdk/algorithm/decorator.py +264 -0
- mantisdk/algorithm/fast.py +250 -0
- mantisdk/algorithm/gepa/__init__.py +59 -0
- mantisdk/algorithm/gepa/adapter.py +459 -0
- mantisdk/algorithm/gepa/gepa.py +364 -0
- mantisdk/algorithm/gepa/lib/__init__.py +18 -0
- mantisdk/algorithm/gepa/lib/adapters/README.md +12 -0
- mantisdk/algorithm/gepa/lib/adapters/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/README.md +341 -0
- mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/__init__.py +1 -0
- mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/anymaths_adapter.py +174 -0
- mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/requirements.txt +1 -0
- mantisdk/algorithm/gepa/lib/adapters/default_adapter/README.md +0 -0
- mantisdk/algorithm/gepa/lib/adapters/default_adapter/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/adapters/default_adapter/default_adapter.py +209 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/README.md +7 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/dspy_adapter.py +307 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/README.md +99 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/dspy_program_proposal_signature.py +137 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/full_program_adapter.py +266 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/GEPA_RAG.md +621 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/__init__.py +56 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/evaluation_metrics.py +226 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/generic_rag_adapter.py +496 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/rag_pipeline.py +238 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_store_interface.py +212 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/__init__.py +2 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/chroma_store.py +196 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/lancedb_store.py +422 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/milvus_store.py +409 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/qdrant_store.py +368 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/weaviate_store.py +418 -0
- mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/README.md +552 -0
- mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/__init__.py +37 -0
- mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/mcp_adapter.py +705 -0
- mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/mcp_client.py +364 -0
- mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/README.md +9 -0
- mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/terminal_bench_adapter.py +217 -0
- mantisdk/algorithm/gepa/lib/api.py +375 -0
- mantisdk/algorithm/gepa/lib/core/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/core/adapter.py +180 -0
- mantisdk/algorithm/gepa/lib/core/data_loader.py +74 -0
- mantisdk/algorithm/gepa/lib/core/engine.py +356 -0
- mantisdk/algorithm/gepa/lib/core/result.py +233 -0
- mantisdk/algorithm/gepa/lib/core/state.py +636 -0
- mantisdk/algorithm/gepa/lib/examples/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/examples/aime.py +24 -0
- mantisdk/algorithm/gepa/lib/examples/anymaths-bench/eval_default.py +111 -0
- mantisdk/algorithm/gepa/lib/examples/anymaths-bench/prompt-templates/instruction_prompt.txt +9 -0
- mantisdk/algorithm/gepa/lib/examples/anymaths-bench/prompt-templates/optimal_prompt.txt +24 -0
- mantisdk/algorithm/gepa/lib/examples/anymaths-bench/train_anymaths.py +177 -0
- mantisdk/algorithm/gepa/lib/examples/dspy_full_program_evolution/arc_agi.ipynb +25705 -0
- mantisdk/algorithm/gepa/lib/examples/dspy_full_program_evolution/example.ipynb +348 -0
- mantisdk/algorithm/gepa/lib/examples/mcp_adapter/__init__.py +4 -0
- mantisdk/algorithm/gepa/lib/examples/mcp_adapter/mcp_optimization_example.py +455 -0
- mantisdk/algorithm/gepa/lib/examples/rag_adapter/RAG_GUIDE.md +613 -0
- mantisdk/algorithm/gepa/lib/examples/rag_adapter/__init__.py +9 -0
- mantisdk/algorithm/gepa/lib/examples/rag_adapter/rag_optimization.py +824 -0
- mantisdk/algorithm/gepa/lib/examples/rag_adapter/requirements-rag.txt +29 -0
- mantisdk/algorithm/gepa/lib/examples/terminal-bench/prompt-templates/instruction_prompt.txt +16 -0
- mantisdk/algorithm/gepa/lib/examples/terminal-bench/prompt-templates/terminus.txt +9 -0
- mantisdk/algorithm/gepa/lib/examples/terminal-bench/train_terminus.py +161 -0
- mantisdk/algorithm/gepa/lib/gepa_utils.py +117 -0
- mantisdk/algorithm/gepa/lib/logging/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/logging/experiment_tracker.py +187 -0
- mantisdk/algorithm/gepa/lib/logging/logger.py +75 -0
- mantisdk/algorithm/gepa/lib/logging/utils.py +103 -0
- mantisdk/algorithm/gepa/lib/proposer/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/proposer/base.py +31 -0
- mantisdk/algorithm/gepa/lib/proposer/merge.py +357 -0
- mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/base.py +49 -0
- mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/reflective_mutation.py +176 -0
- mantisdk/algorithm/gepa/lib/py.typed +0 -0
- mantisdk/algorithm/gepa/lib/strategies/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/strategies/batch_sampler.py +77 -0
- mantisdk/algorithm/gepa/lib/strategies/candidate_selector.py +50 -0
- mantisdk/algorithm/gepa/lib/strategies/component_selector.py +36 -0
- mantisdk/algorithm/gepa/lib/strategies/eval_policy.py +64 -0
- mantisdk/algorithm/gepa/lib/strategies/instruction_proposal.py +127 -0
- mantisdk/algorithm/gepa/lib/utils/__init__.py +10 -0
- mantisdk/algorithm/gepa/lib/utils/stop_condition.py +196 -0
- mantisdk/algorithm/gepa/tracing.py +105 -0
- mantisdk/algorithm/utils.py +177 -0
- mantisdk/algorithm/verl/__init__.py +5 -0
- mantisdk/algorithm/verl/interface.py +202 -0
- mantisdk/cli/__init__.py +56 -0
- mantisdk/cli/prometheus.py +115 -0
- mantisdk/cli/store.py +131 -0
- mantisdk/cli/vllm.py +29 -0
- mantisdk/client.py +408 -0
- mantisdk/config.py +348 -0
- mantisdk/emitter/__init__.py +43 -0
- mantisdk/emitter/annotation.py +370 -0
- mantisdk/emitter/exception.py +54 -0
- mantisdk/emitter/message.py +61 -0
- mantisdk/emitter/object.py +117 -0
- mantisdk/emitter/reward.py +320 -0
- mantisdk/env_var.py +156 -0
- mantisdk/execution/__init__.py +15 -0
- mantisdk/execution/base.py +64 -0
- mantisdk/execution/client_server.py +443 -0
- mantisdk/execution/events.py +69 -0
- mantisdk/execution/inter_process.py +16 -0
- mantisdk/execution/shared_memory.py +282 -0
- mantisdk/instrumentation/__init__.py +119 -0
- mantisdk/instrumentation/agentops.py +314 -0
- mantisdk/instrumentation/agentops_langchain.py +45 -0
- mantisdk/instrumentation/litellm.py +83 -0
- mantisdk/instrumentation/vllm.py +81 -0
- mantisdk/instrumentation/weave.py +500 -0
- mantisdk/litagent/__init__.py +11 -0
- mantisdk/litagent/decorator.py +536 -0
- mantisdk/litagent/litagent.py +252 -0
- mantisdk/llm_proxy.py +1890 -0
- mantisdk/logging.py +370 -0
- mantisdk/reward.py +7 -0
- mantisdk/runner/__init__.py +11 -0
- mantisdk/runner/agent.py +845 -0
- mantisdk/runner/base.py +182 -0
- mantisdk/runner/legacy.py +309 -0
- mantisdk/semconv.py +170 -0
- mantisdk/server.py +401 -0
- mantisdk/store/__init__.py +23 -0
- mantisdk/store/base.py +897 -0
- mantisdk/store/client_server.py +2092 -0
- mantisdk/store/collection/__init__.py +30 -0
- mantisdk/store/collection/base.py +587 -0
- mantisdk/store/collection/memory.py +970 -0
- mantisdk/store/collection/mongo.py +1412 -0
- mantisdk/store/collection_based.py +1823 -0
- mantisdk/store/insight.py +648 -0
- mantisdk/store/listener.py +58 -0
- mantisdk/store/memory.py +396 -0
- mantisdk/store/mongo.py +165 -0
- mantisdk/store/sqlite.py +3 -0
- mantisdk/store/threading.py +357 -0
- mantisdk/store/utils.py +142 -0
- mantisdk/tracer/__init__.py +16 -0
- mantisdk/tracer/agentops.py +242 -0
- mantisdk/tracer/base.py +287 -0
- mantisdk/tracer/dummy.py +106 -0
- mantisdk/tracer/otel.py +555 -0
- mantisdk/tracer/weave.py +677 -0
- mantisdk/trainer/__init__.py +6 -0
- mantisdk/trainer/init_utils.py +263 -0
- mantisdk/trainer/legacy.py +367 -0
- mantisdk/trainer/registry.py +12 -0
- mantisdk/trainer/trainer.py +618 -0
- mantisdk/types/__init__.py +6 -0
- mantisdk/types/core.py +553 -0
- mantisdk/types/resources.py +204 -0
- mantisdk/types/tracer.py +515 -0
- mantisdk/types/tracing.py +218 -0
- mantisdk/utils/__init__.py +1 -0
- mantisdk/utils/id.py +18 -0
- mantisdk/utils/metrics.py +1025 -0
- mantisdk/utils/otel.py +578 -0
- mantisdk/utils/otlp.py +536 -0
- mantisdk/utils/server_launcher.py +1045 -0
- mantisdk/utils/system_snapshot.py +81 -0
- mantisdk/verl/__init__.py +8 -0
- mantisdk/verl/__main__.py +6 -0
- mantisdk/verl/async_server.py +46 -0
- mantisdk/verl/config.yaml +27 -0
- mantisdk/verl/daemon.py +1154 -0
- mantisdk/verl/dataset.py +44 -0
- mantisdk/verl/entrypoint.py +248 -0
- mantisdk/verl/trainer.py +549 -0
- mantisdk-0.1.0.dist-info/METADATA +119 -0
- mantisdk-0.1.0.dist-info/RECORD +190 -0
- mantisdk-0.1.0.dist-info/WHEEL +4 -0
- mantisdk-0.1.0.dist-info/entry_points.txt +2 -0
- mantisdk-0.1.0.dist-info/licenses/LICENSE +19 -0
mantisdk/verl/trainer.py
ADDED
|
@@ -0,0 +1,549 @@
|
|
|
1
|
+
# Copyright (c) Microsoft. All rights reserved.
|
|
2
|
+
|
|
3
|
+
# type: ignore
|
|
4
|
+
|
|
5
|
+
from __future__ import annotations

import random
from contextlib import contextmanager
from copy import deepcopy
from pprint import pprint
from typing import Any, Dict, Tuple, Type

import numpy as np
import torch
import verl
from codetiming import Timer
from omegaconf import OmegaConf
from tqdm import tqdm
from verl import DataProto
from verl.protocol import pad_dataproto_to_divisor, unpad_dataproto
from verl.trainer.ppo.core_algos import agg_loss
from verl.trainer.ppo.metric_utils import (
    _compute_response_info,
    compute_throughout_metrics,
    compute_timing_metrics,
)
from verl.trainer.ppo.ray_trainer import (
    AdvantageEstimator,
    RayPPOTrainer,
    apply_kl_penalty,
    compute_advantage,
    compute_response_mask,
)
from verl.utils.metric import reduce_metrics
from verl.utils.tracking import Tracking

from mantisdk.adapter import TraceAdapter, TraceToTripletBase
from mantisdk.llm_proxy import LLMProxy
from mantisdk.store.base import LightningStore

from .daemon import AgentModeDaemon
|
|
42
|
+
|
|
43
|
+
# Public API of this module: only the trainer class is exported.
__all__ = ["MantisdkTrainer"]
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
@contextmanager
def _timer(name: str, timing_raw: Dict[str, float]):
    """Time the wrapped block and add the elapsed time to ``timing_raw[name]``.

    The elapsed time is the ``last`` attribute of the ``codetiming.Timer`` used here.
    Re-using the same ``name`` accumulates into the existing entry instead of
    overwriting it. The accumulation happens only after the wrapped block exits
    normally; if it raises, ``timing_raw`` is left unchanged.
    """
    with Timer(name=name, logger=None) as stopwatch:
        yield
    timing_raw[name] = timing_raw.get(name, 0) + stopwatch.last
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
# This function is adapted from verl's `compute_data_metrics`.
# We introduce a new parameter `suffix` to distinguish between metrics computed
# before and after Mantisdk's post-processing.
# - "Before" refers to raw reward and advantage values.
# - "After" refers to values computed following post-processing, which involves:
#   (1) Dropping prompts that exceed the maximum allowed length.
#   (2) Adjusting the batch size to be a multiple of the mini PPO size.
# Different suffixes are used to label these two stages accordingly.
def compute_data_metrics(batch: DataProto, use_critic: bool = True, suffix: str = "") -> Dict[str, Any]:
    """
    Computes various metrics from a batch of data for PPO training.

    This function calculates metrics related to scores, rewards, advantages, returns, values,
    and sequence lengths from a batch of data. It provides statistical information (mean, max, min)
    for each metric category.

    Args:
        batch: A DataProto object containing batch data with token-level scores, rewards, advantages, etc.
            Reads ``batch.batch`` keys ``token_level_scores``, ``token_level_rewards``, ``advantages``,
            ``returns``, ``responses``, ``attention_mask`` and, when ``use_critic`` is True, ``values``.
        use_critic: Whether to include critic-specific metrics. Defaults to True.
        suffix: String appended to every metric key (e.g. ``"_before_processing"``), so that the
            same statistic can be logged for several stages of post-processing. Defaults to ``""``.

    Returns:
        A dictionary of metrics (every key ends with ``suffix``) including:
            - critic/score/mean, max, min: Statistics about sequence scores
            - critic/rewards/mean, max, min: Statistics about sequence rewards
            - critic/advantages/mean, max, min: Statistics about advantages
            - critic/returns/mean, max, min: Statistics about returns
            - critic/values/mean, max, min: Statistics about critic values (if use_critic=True)
            - critic/vf_explained_var: Explained variance of the value function (if use_critic=True)
            - response_length/mean, max, min, clip_ratio: Statistics about response lengths
            - prompt_length/mean, max, min, clip_ratio: Statistics about prompt lengths
    """
    sequence_score = batch.batch["token_level_scores"].sum(-1)
    sequence_reward = batch.batch["token_level_rewards"].sum(-1)

    advantages = batch.batch["advantages"]
    returns = batch.batch["returns"]

    max_response_length = batch.batch["responses"].shape[-1]

    # The attention mask spans prompt + response: the last `max_response_length`
    # columns are the response segment, everything before them the prompt segment.
    prompt_mask = batch.batch["attention_mask"][:, :-max_response_length].bool()
    response_mask = batch.batch["attention_mask"][:, -max_response_length:].bool()

    max_prompt_length = prompt_mask.size(-1)

    # Per-sample lengths are the number of attended tokens in each segment. This mirrors
    # verl's private `_compute_response_info`, but it is computed from the masks sliced
    # above. That avoids slicing the attention mask a second time and avoids depending on
    # a private verl helper whose location can change between verl versions.
    prompt_length = prompt_mask.sum(-1).float()
    response_length = response_mask.sum(-1).float()

    # Advantages/returns only count on attended response tokens.
    valid_adv = torch.masked_select(advantages, response_mask)
    valid_returns = torch.masked_select(returns, response_mask)

    if use_critic:
        values = batch.batch["values"]
        valid_values = torch.masked_select(values, response_mask)
        return_diff_var = torch.var(valid_returns - valid_values)
        return_var = torch.var(valid_returns)

    metrics = {
        # score
        "critic/score/mean" + suffix: torch.mean(sequence_score).detach().item(),
        "critic/score/max" + suffix: torch.max(sequence_score).detach().item(),
        "critic/score/min" + suffix: torch.min(sequence_score).detach().item(),
        # reward
        "critic/rewards/mean" + suffix: torch.mean(sequence_reward).detach().item(),
        "critic/rewards/max" + suffix: torch.max(sequence_reward).detach().item(),
        "critic/rewards/min" + suffix: torch.min(sequence_reward).detach().item(),
        # adv
        "critic/advantages/mean" + suffix: torch.mean(valid_adv).detach().item(),
        "critic/advantages/max" + suffix: torch.max(valid_adv).detach().item(),
        "critic/advantages/min" + suffix: torch.min(valid_adv).detach().item(),
        # returns
        "critic/returns/mean" + suffix: torch.mean(valid_returns).detach().item(),
        "critic/returns/max" + suffix: torch.max(valid_returns).detach().item(),
        "critic/returns/min" + suffix: torch.min(valid_returns).detach().item(),
        **(
            {
                # values
                "critic/values/mean" + suffix: torch.mean(valid_values).detach().item(),
                "critic/values/max" + suffix: torch.max(valid_values).detach().item(),
                "critic/values/min" + suffix: torch.min(valid_values).detach().item(),
                # vf explained var; 1e-5 guards against a zero return variance
                "critic/vf_explained_var" + suffix: (1.0 - return_diff_var / (return_var + 1e-5)).detach().item(),
            }
            if use_critic
            else {}
        ),
        # response length
        "response_length/mean" + suffix: torch.mean(response_length).detach().item(),
        "response_length/max" + suffix: torch.max(response_length).detach().item(),
        "response_length/min" + suffix: torch.min(response_length).detach().item(),
        # fraction of samples whose response fills the whole response window
        "response_length/clip_ratio"
        + suffix: torch.mean(torch.eq(response_length, max_response_length).float()).detach().item(),
        # prompt length
        "prompt_length/mean" + suffix: torch.mean(prompt_length).detach().item(),
        "prompt_length/max" + suffix: torch.max(prompt_length).detach().item(),
        "prompt_length/min" + suffix: torch.min(prompt_length).detach().item(),
        # fraction of samples whose prompt fills the whole prompt window
        "prompt_length/clip_ratio"
        + suffix: torch.mean(torch.eq(prompt_length, max_prompt_length).float()).detach().item(),
    }
    return metrics
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
class MantisdkTrainer(RayPPOTrainer):
|
|
160
|
+
"""
|
|
161
|
+
Specialized PPO trainer for agent-based reinforcement learning.
|
|
162
|
+
|
|
163
|
+
This trainer is designed specifically for scenarios where the model interacts with
|
|
164
|
+
external environments, tools, or APIs through an MantisdkServer. It simplifies
|
|
165
|
+
the training loop by removing the complex conditional logic present in the original
|
|
166
|
+
RayPPOTrainer and focusing on the agent mode workflow.
|
|
167
|
+
|
|
168
|
+
Key differences from RayPPOTrainer:
|
|
169
|
+
|
|
170
|
+
1. Uses AgentModeDaemon for server communication
|
|
171
|
+
2. Simplified data flow without pop/union operations
|
|
172
|
+
3. Direct batch processing through agent daemon
|
|
173
|
+
4. Streamlined validation using agent_mode validation
|
|
174
|
+
"""
|
|
175
|
+
|
|
176
|
+
def __init__(
    self,
    store: LightningStore | None,
    llm_proxy: LLMProxy | None,
    adapter: TraceAdapter | None,
    daemon_cls: Type[AgentModeDaemon],
    **kwargs,
):
    """Build the trainer on top of ``RayPPOTrainer``.

    Args:
        store: Optional store. ``fit`` starts the agent-mode daemon in ``"v1"`` mode
            when this is not None and in ``"v0"`` mode otherwise.
        llm_proxy: Optional LLM proxy, kept as ``self.llm_proxy``.
        adapter: Optional trace adapter. ``fit`` rejects any adapter that is not a
            ``TraceToTripletBase``.
        daemon_cls: Daemon class that ``fit`` instantiates as ``self.agent_mode_daemon``.
        **kwargs: Forwarded unchanged to ``RayPPOTrainer.__init__``.
    """
    # The base trainer is initialised first; the Mantisdk-specific attributes follow.
    super().__init__(**kwargs)
    self.store, self.llm_proxy, self.adapter, self.daemon_cls = store, llm_proxy, adapter, daemon_cls
|
|
189
|
+
|
|
190
|
+
def _validate(self):
|
|
191
|
+
assert len(self.val_dataloader) == 1, "Please set val_batch_size to None for better throughput."
|
|
192
|
+
|
|
193
|
+
test_data = next(iter(self.val_dataloader))
|
|
194
|
+
test_batch = DataProto.from_single_dict(test_data)
|
|
195
|
+
|
|
196
|
+
self.async_rollout_manager.wake_up()
|
|
197
|
+
self.agent_mode_daemon.set_up_data_and_server(
|
|
198
|
+
test_batch.non_tensor_batch,
|
|
199
|
+
self.async_rollout_manager.server_addresses,
|
|
200
|
+
is_train=False,
|
|
201
|
+
)
|
|
202
|
+
self.agent_mode_daemon.run_until_all_finished()
|
|
203
|
+
test_metrics = self.agent_mode_daemon.get_test_metrics()
|
|
204
|
+
self.agent_mode_daemon.clear_data_and_server()
|
|
205
|
+
self.async_rollout_manager.sleep()
|
|
206
|
+
return test_metrics
|
|
207
|
+
|
|
208
|
+
def _compute_reference_log_prob(self, batch: DataProto) -> DataProto:
|
|
209
|
+
"""Compute reference log probability using the correct worker based on LoRA configuration.
|
|
210
|
+
|
|
211
|
+
In verl 0.6.0+, when LoRA is detected (indicated by ref_in_actor=True),
|
|
212
|
+
the reference policy is computed by the actor rollout worker instead of a separate
|
|
213
|
+
ref policy worker. This method handles both scenarios by checking the ref_in_actor flag.
|
|
214
|
+
Note: verl sets ref_in_actor=True when it detects LoRA configuration (e.g., lora_rank > 0 or lora_adapter_path is set).
|
|
215
|
+
|
|
216
|
+
Args:
|
|
217
|
+
batch: The data batch to compute reference log probabilities for.
|
|
218
|
+
|
|
219
|
+
Returns:
|
|
220
|
+
DataProto with reference log probabilities added.
|
|
221
|
+
|
|
222
|
+
Raises:
|
|
223
|
+
RuntimeError: If the required worker is not available.
|
|
224
|
+
"""
|
|
225
|
+
if getattr(self, "ref_in_actor", False):
|
|
226
|
+
actor_worker = getattr(self, "actor_rollout_wg", None)
|
|
227
|
+
if actor_worker is None:
|
|
228
|
+
raise RuntimeError("actor_rollout_wg is required when ref_in_actor is True.")
|
|
229
|
+
return actor_worker.compute_ref_log_prob(batch)
|
|
230
|
+
|
|
231
|
+
ref_worker = getattr(self, "ref_policy_wg", None)
|
|
232
|
+
if ref_worker is None:
|
|
233
|
+
raise RuntimeError(
|
|
234
|
+
"Reference policy worker was not initialized. "
|
|
235
|
+
"Ensure `use_reference_policy` is enabled and the VERL config exposes the ref worker."
|
|
236
|
+
)
|
|
237
|
+
return ref_worker.compute_ref_log_prob(batch)
|
|
238
|
+
|
|
239
|
+
def _train_step(self, batch_dict: dict) -> dict:
    """Run one PPO training step on rollouts produced by the agent-mode daemon.

    The step proceeds as follows:

    1. Hand the batch's ``non_tensor_batch`` to the agent-mode daemon, let the agents run
       against the rollout servers, and collect the resulting training batch.
    2. Pad the batch to a multiple of the actor world size. Recompute old log-probs, plus
       reference log-probs and critic values when enabled, then unpad.
    3. Compute token-level rewards (KL-penalised when ``use_kl_in_reward``) and advantages.
    4. Drop samples flagged in ``is_drop_mask``, shuffle, and truncate to a multiple of
       ``ppo_mini_batch_size``.
    5. Optionally balance the batch. Update the critic, and update the actor once
       ``critic_warmup`` has been reached.
    6. Optionally dump generations, then compute data, timing and throughput metrics.

    Args:
        batch_dict: One batch from the training dataloader, converted with
            ``DataProto.from_single_dict``.

    Returns:
        Dictionary of metrics collected during this step.
    """
    # Isolate in a separate method to automatically recycle the variables before validation.
    batch: DataProto = DataProto.from_single_dict(batch_dict)
    metrics = {}
    timing_raw = {}

    with _timer("step", timing_raw):

        # When agent mode is enabled, we read the batch as it is.
        gen_batch = batch

        # generate a batch
        with _timer("gen", timing_raw):
            self.async_rollout_manager.wake_up()
            self.agent_mode_daemon.set_up_data_and_server(
                gen_batch.non_tensor_batch, self.async_rollout_manager.server_addresses
            )
            self.agent_mode_daemon.run_until_all_finished()
            # Trajectory-level aggregation has its own length limits; otherwise the
            # data config limits apply.
            aggregator_cfg = self.config.mantisdk.trace_aggregator
            use_trajectory_limits = aggregator_cfg.level.startswith("trajectory")
            batch, agent_metrics = self.agent_mode_daemon.get_train_data_batch(
                max_prompt_length=(
                    aggregator_cfg.trajectory_max_prompt_length
                    if use_trajectory_limits
                    else self.config.data.max_prompt_length
                ),
                max_response_length=(
                    aggregator_cfg.trajectory_max_response_length
                    if use_trajectory_limits
                    else self.config.data.max_response_length
                ),
                device=gen_batch.batch["fake_ids"].device,
                global_steps=self.global_steps,
            )
            metrics.update(agent_metrics)
            self.agent_mode_daemon.clear_data_and_server()
            self.async_rollout_manager.sleep()

        if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX:
            with _timer("gen_max", timing_raw):
                gen_baseline_batch = deepcopy(gen_batch)
                gen_baseline_batch.meta_info["do_sample"] = False
                gen_baseline_output = self.async_rollout_manager.generate_sequences(gen_baseline_batch)

                batch = batch.union(gen_baseline_output)
                reward_baseline_tensor = self.reward_fn(batch)
                reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1)

                batch.pop(batch_keys=list(gen_baseline_output.batch.keys()))

                batch.batch["reward_baselines"] = reward_baseline_tensor

                del gen_baseline_batch, gen_baseline_output

        # uid is used for algorithm like GRPO, should be aligned to data id
        batch.non_tensor_batch["uid"] = batch.non_tensor_batch["data_id_list"]

        if "response_mask" not in batch.batch:
            batch.batch["response_mask"] = compute_response_mask(batch)

        # compute global_valid tokens
        batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist()

        with _timer("reward", timing_raw):
            # compute reward model score
            if self.use_rm:
                reward_tensor = self.rm_wg.compute_rm_score(batch)
                batch = batch.union(reward_tensor)

            reward_extra_infos_dict = {}

        # for agent mode, pad the lengths to calculate old log prob, ref, and values
        batch, pad_size = pad_dataproto_to_divisor(batch, self.actor_rollout_wg.world_size)

        # recompute old_log_probs
        with _timer("old_log_prob", timing_raw):
            old_log_prob = self.actor_rollout_wg.compute_log_prob(batch)
            entropys = old_log_prob.batch["entropys"]
            response_masks = batch.batch["response_mask"]
            loss_agg_mode = self.config.actor_rollout_ref.actor.loss_agg_mode
            entropy_loss = agg_loss(loss_mat=entropys, loss_mask=response_masks, loss_agg_mode=loss_agg_mode)
            old_log_prob_metrics = {"actor/entropy_loss": entropy_loss.detach().item()}
            metrics.update(old_log_prob_metrics)
            old_log_prob.batch.pop("entropys")
            batch = batch.union(old_log_prob)

        if self.use_reference_policy:
            # compute reference log_prob
            with _timer("ref", timing_raw):
                ref_log_prob = self._compute_reference_log_prob(batch)
                batch = batch.union(ref_log_prob)

        # compute values
        if self.use_critic:
            with _timer("values", timing_raw):
                values = self.critic_wg.compute_values(batch)
                batch = batch.union(values)

        # for agent mode, unpad to calculate adv
        # it is important, as adv should be based on the raw traces
        batch = unpad_dataproto(batch, pad_size=pad_size)

        with _timer("adv", timing_raw):
            # if agent_mode is enabled, there is already token_level_scores
            # token_level_scores is not needed to compute here

            # compute rewards. apply_kl_penalty if available
            if self.config.algorithm.use_kl_in_reward:
                batch, kl_metrics = apply_kl_penalty(
                    batch, kl_ctrl=self.kl_ctrl_in_reward, kl_penalty=self.config.algorithm.kl_penalty
                )
                metrics.update(kl_metrics)
            else:
                batch.batch["token_level_rewards"] = batch.batch["token_level_scores"]

            # compute advantages, executed on the driver process
            norm_adv_by_std_in_grpo = self.config.algorithm.get(
                "norm_adv_by_std_in_grpo", True
            )  # GRPO adv normalization factor

            batch = compute_advantage(
                batch,
                adv_estimator=self.config.algorithm.adv_estimator,
                gamma=self.config.algorithm.gamma,
                lam=self.config.algorithm.lam,
                num_repeat=self.config.actor_rollout_ref.rollout.n,
                norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo,
                config=self.config.algorithm,
            )

        # Calculate the metrics before processing. Refer to the comments of function `compute_data_metrics` for details.
        metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic, suffix="_before_processing"))

        # after advantages are assigned, we begin to drop (1) long prompt (2) floor to ppo minisize
        keep_indices = (~batch.batch["is_drop_mask"]).nonzero(as_tuple=True)[0]
        metrics["training/n_triplets_prompt_too_long"] = (
            batch.batch["is_drop_mask"].shape[0] - keep_indices.shape[0]
        )
        batch = batch[keep_indices]
        # next, shuffle and round down to a multiple of the PPO mini-batch size
        mini_batch_size = self.config.actor_rollout_ref.actor.ppo_mini_batch_size
        n_transition = len(batch)
        random_indices = list(range(n_transition))
        random.shuffle(random_indices)
        batch.reorder(torch.tensor(random_indices).type(torch.int32))
        n_remained_transition = n_transition // mini_batch_size * mini_batch_size
        batch = batch[list(range(n_remained_transition))]
        metrics["training/n_triplets_dropped_remainder"] = n_transition - n_remained_transition

        # Agent mode note: Change the order of balance batch;
        #     1. first calculate advantage
        #     2. then drop the samples (too long prompt & floor to ppo minisize)
        #     3. balance
        # balance the number of valid tokens on each dp rank.
        # Note that this breaks the order of data inside the batch.
        # Please take care when you implement group based adv computation such as GRPO and rloo
        if self.config.trainer.balance_batch:
            self._balance_batch(batch, metrics=metrics)

        # update critic
        if self.use_critic:
            with _timer("update_critic", timing_raw):
                critic_output = self.critic_wg.update_critic(batch)
            critic_output_metrics = reduce_metrics(critic_output.meta_info["metrics"])
            metrics.update(critic_output_metrics)

        # implement critic warmup
        if self.config.trainer.critic_warmup <= self.global_steps:
            # update actor
            with _timer("update_actor", timing_raw):
                batch.meta_info["multi_turn"] = self.config.actor_rollout_ref.rollout.multi_turn.enable
                actor_output = self.actor_rollout_wg.update_actor(batch)
            actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"])
            metrics.update(actor_output_metrics)

        # Log rollout generations if enabled
        rollout_data_dir = self.config.trainer.get("rollout_data_dir", None)
        if rollout_data_dir:
            with _timer("dump_rollout_generations", timing_raw):
                # (A leftover debug `print(batch.batch.keys())` used to run here on every step.)
                inputs = self.tokenizer.batch_decode(batch.batch["prompts"], skip_special_tokens=True)
                outputs = self.tokenizer.batch_decode(batch.batch["responses"], skip_special_tokens=True)
                scores = batch.batch["token_level_scores"].sum(-1).cpu().tolist()
                self._dump_generations(
                    inputs=inputs,
                    outputs=outputs,
                    scores=scores,
                    reward_extra_infos_dict=reward_extra_infos_dict,
                    dump_path=rollout_data_dir,
                )

    # compute training metrics (outside the "step" timer so timing_raw["step"] is populated)
    metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic, suffix="_after_processing"))
    metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw))
    # TODO: implement actual tflpo and theoretical tflpo
    n_gpus = self.resource_pool_manager.get_n_gpus()
    metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus))

    return metrics
|
|
437
|
+
|
|
438
|
+
def fit(self):
    """Run the agent-mode training loop.

    Creates a ``Tracking`` logger, resets ``self.global_steps`` and restores the
    latest checkpoint, checks the rollout/adapter configuration, and starts the
    agent-mode daemon (``self.daemon_cls``).

    If ``val_reward_fn`` is set and ``trainer.val_before_train`` is true (the
    default), it validates once before training. With ``trainer.val_only`` it
    then returns.

    Training iterates ``trainer.total_epochs`` times over
    ``self.train_dataloader`` and calls ``self._train_step`` once per batch.
    Validation runs every ``trainer.test_freq`` steps and checkpointing every
    ``trainer.save_freq`` steps; both also run on the last step. Metrics are
    logged once per step.

    Returns early (``None``) after the step where
    ``global_steps >= total_training_steps``.

    Raises:
        ValueError: If ``self.async_rollout_mode`` is falsy, or if
            ``self.adapter`` is set but is not a ``TraceToTripletBase``.
    """
    logger = Tracking(
        project_name=self.config.trainer.project_name,
        experiment_name=self.config.trainer.experiment_name,
        default_backend=self.config.trainer.logger,
        config=OmegaConf.to_container(self.config, resolve=True),
    )

    self.global_steps = 0

    # Load checkpoint before doing anything (presumably restores
    # self.global_steps, which seeds the progress bar below -- confirm).
    self._load_checkpoint()

    # Validate configuration explicitly: an `assert` here would be stripped
    # when Python runs with -O, letting a misconfigured run proceed.
    if not self.async_rollout_mode:
        raise ValueError("If agent mode is enabled, async server must be enabled")
    if self.adapter is not None and not isinstance(self.adapter, TraceToTripletBase):
        raise ValueError("Adapter must be a TraceToTripletBase for currently VERL implementation.")

    if verl.__version__ == "0.5.0":
        # Note (Zhiyuan): To avoid further patch into vllm async server, using the same sentence to get the naming here.
        # However, it is possible that verl updates the naming and causes incompatibility.
        # Reference: https://github.com/volcengine/verl/blob/5b5e09d9cc20625e436d01f69d9cc739ff681c54/verl/workers/rollout/vllm_rollout/vllm_async_server.py#L217
        model = "/".join(self.config.actor_rollout_ref.model.path.split("/")[-2:])
    else:
        # For other versions (e.g., 0.6.0), we use the full path to the model.
        model = self.config.actor_rollout_ref.model.path

    self.agent_mode_daemon = self.daemon_cls(
        self.config.mantisdk.port,
        self.config.actor_rollout_ref.rollout.n,
        train_information={
            "model": model,
            "temperature": self.config.actor_rollout_ref.rollout.temperature,
        },
        tokenizer=self.tokenizer,
        mini_batch_size=self.config.actor_rollout_ref.actor.ppo_mini_batch_size,
        pad_token_id=self.tokenizer.pad_token_id,
        mode="v1" if self.store is not None else "v0",
        store=self.store,
        llm_proxy=self.llm_proxy,
        adapter=self.adapter,
        processor=self.processor,  # For Qwen2-VL mrope position_ids
        image_base_dir=getattr(self.config.data, "image_base_dir", None),
        trace_aggregator=self.config.mantisdk.trace_aggregator,
    )
    self.agent_mode_daemon.start()

    # Perform validation before training.
    # Currently, we only support validation using the reward_function.
    if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True):
        val_metrics = self._validate()
        assert val_metrics, f"{val_metrics=}"
        pprint(f"Initial validation metrics: {val_metrics}")
        logger.log(data=val_metrics, step=self.global_steps)
        if self.config.trainer.get("val_only", False):
            return

    progress_bar = tqdm(total=self.total_training_steps, initial=self.global_steps, desc="Training Progress")

    # We start from step 1.
    self.global_steps += 1
    last_val_metrics = None

    for epoch in range(self.config.trainer.total_epochs):
        for batch_dict in self.train_dataloader:
            # Timings for the validate/save_checkpoint phases of this
            # iteration; _train_step returns its own metrics dict.
            timing_raw = {}
            is_last_step = self.global_steps >= self.total_training_steps

            # Train step.
            metrics = self._train_step(batch_dict)

            # Validate.
            if (
                self.val_reward_fn is not None
                and self.config.trainer.test_freq > 0
                and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0)
            ):
                with _timer("validate", timing_raw):
                    val_metrics: dict = self._validate()
                    if is_last_step:
                        last_val_metrics = val_metrics
                metrics.update(val_metrics)

            if self.config.trainer.save_freq > 0 and (
                is_last_step or self.global_steps % self.config.trainer.save_freq == 0
            ):
                with _timer("save_checkpoint", timing_raw):
                    self._save_checkpoint()

            # Report the timings measured above. Previously they were collected
            # into timing_raw and then dropped. The `timing_s/` prefix is
            # intended to match verl's compute_timing_metrics naming -- confirm.
            metrics.update({f"timing_s/{name}": value for name, value in timing_raw.items()})

            # Step metrics.
            metrics.update(
                {
                    "training/global_step": self.global_steps,
                    "training/epoch": epoch,
                }
            )

            # TODO: make a canonical logger that supports various backend
            logger.log(data=metrics, step=self.global_steps)

            if is_last_step:
                pprint(f"Final validation metrics: {last_val_metrics}")
                progress_bar.close()

                # This exit logic is to ensure a robust CI.
                pprint("Flush the logger...")
                del logger  # Make sure the loggers are flushed and closed properly
                pprint(f"Training finished at step {self.global_steps}.")
                return

            progress_bar.update(1)
            self.global_steps += 1

    # Reached only when the dataloader is exhausted before total_training_steps
    # is hit; close the progress bar the last-step branch would have closed.
    progress_bar.close()
|
|
@@ -0,0 +1,119 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: mantisdk
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Mantisdk - AI Agent Training and Evaluation Platform
|
|
5
|
+
Project-URL: Homepage, https://github.com/withmetis/mantis
|
|
6
|
+
Project-URL: Documentation, https://withmetis.github.io/mantis/mantisdk/
|
|
7
|
+
Project-URL: Repository, https://github.com/withmetis/mantis
|
|
8
|
+
Project-URL: Issues, https://github.com/withmetis/mantis/issues
|
|
9
|
+
Author-email: Metis Team <team@withmetis.ai>
|
|
10
|
+
License: MIT
|
|
11
|
+
License-File: LICENSE
|
|
12
|
+
Keywords: agents,ai,evaluation,llm,mantis,observability,reinforcement-learning,training
|
|
13
|
+
Classifier: Development Status :: 4 - Beta
|
|
14
|
+
Classifier: Intended Audience :: Developers
|
|
15
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
16
|
+
Classifier: Programming Language :: Python :: 3
|
|
17
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
18
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
19
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
20
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
21
|
+
Requires-Python: >=3.10
|
|
22
|
+
Requires-Dist: agentops>=0.4.13
|
|
23
|
+
Requires-Dist: aiohttp
|
|
24
|
+
Requires-Dist: aiologic
|
|
25
|
+
Requires-Dist: fastapi
|
|
26
|
+
Requires-Dist: flask
|
|
27
|
+
Requires-Dist: gepa>=0.0.24
|
|
28
|
+
Requires-Dist: gpustat
|
|
29
|
+
Requires-Dist: graphviz
|
|
30
|
+
Requires-Dist: gunicorn
|
|
31
|
+
Requires-Dist: litellm[proxy]>=1.74
|
|
32
|
+
Requires-Dist: openai
|
|
33
|
+
Requires-Dist: opentelemetry-api>=1.35
|
|
34
|
+
Requires-Dist: opentelemetry-exporter-otlp>=1.35
|
|
35
|
+
Requires-Dist: opentelemetry-sdk>=1.35
|
|
36
|
+
Requires-Dist: portpicker
|
|
37
|
+
Requires-Dist: psutil
|
|
38
|
+
Requires-Dist: pydantic>=2.11
|
|
39
|
+
Requires-Dist: rich
|
|
40
|
+
Requires-Dist: setproctitle
|
|
41
|
+
Requires-Dist: uvicorn
|
|
42
|
+
Requires-Dist: uvicorn-worker
|
|
43
|
+
Provides-Extra: apo
|
|
44
|
+
Requires-Dist: poml; extra == 'apo'
|
|
45
|
+
Provides-Extra: mongo
|
|
46
|
+
Requires-Dist: pymongo; extra == 'mongo'
|
|
47
|
+
Provides-Extra: verl
|
|
48
|
+
Requires-Dist: verl>=0.5.0; extra == 'verl'
|
|
49
|
+
Requires-Dist: vllm>=0.8.4; extra == 'verl'
|
|
50
|
+
Provides-Extra: weave
|
|
51
|
+
Requires-Dist: weave>=0.52.22; extra == 'weave'
|
|
52
|
+
Description-Content-Type: text/markdown
|
|
53
|
+
|
|
54
|
+
# Mantisdk
|
|
55
|
+
|
|
56
|
+
[![PyPI version](https://badge.fury.io/py/mantisdk.svg)](https://badge.fury.io/py/mantisdk)
|
|
57
|
+
[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](LICENSE)
|
|
58
|
+
|
|
59
|
+
**AI Agent Training and Evaluation Platform**
|
|
60
|
+
|
|
61
|
+
Mantisdk is a comprehensive toolkit for training and evaluating AI agents using reinforcement learning, automatic prompt optimization, and supervised fine-tuning.
|
|
62
|
+
|
|
63
|
+
## Core Features
|
|
64
|
+
|
|
65
|
+
- Turn your agent into an optimizable beast with **minimal code changes**
|
|
66
|
+
- Build with **any** agent framework (LangChain, OpenAI Agents SDK, AutoGen, CrewAI, and more)
|
|
67
|
+
- **Selectively** optimize one or more agents in a multi-agent system
|
|
68
|
+
- Supports **algorithms** such as reinforcement learning, automatic prompt optimization, supervised fine-tuning, and more
|
|
69
|
+
|
|
70
|
+
## Installation
|
|
71
|
+
|
|
72
|
+
```bash
|
|
73
|
+
pip install mantisdk
|
|
74
|
+
```
|
|
75
|
+
|
|
76
|
+
For optional dependencies:
|
|
77
|
+
|
|
78
|
+
```bash
|
|
79
|
+
# For APO (Automatic Prompt Optimization)
|
|
80
|
+
pip install "mantisdk[apo]"
|
|
81
|
+
|
|
82
|
+
# For VERL integration
|
|
83
|
+
pip install "mantisdk[verl]"
|
|
84
|
+
|
|
85
|
+
# For Weave integration
|
|
86
|
+
pip install "mantisdk[weave]"
|
|
87
|
+
|
|
88
|
+
# For MongoDB store
|
|
89
|
+
pip install "mantisdk[mongo]"
|
|
90
|
+
```
|
|
91
|
+
|
|
92
|
+
## Quick Start
|
|
93
|
+
|
|
94
|
+
```python
|
|
95
|
+
import mantisdk as msk
|
|
96
|
+
|
|
97
|
+
# Initialize the client
|
|
98
|
+
client = msk.MantisdkClient()
|
|
99
|
+
|
|
100
|
+
# Your agent code here...
|
|
101
|
+
```
|
|
102
|
+
|
|
103
|
+
## CLI Usage
|
|
104
|
+
|
|
105
|
+
```bash
|
|
106
|
+
# Start the Mantisdk store server
|
|
107
|
+
msk store serve
|
|
108
|
+
|
|
109
|
+
# Run with vLLM
|
|
110
|
+
msk vllm start
|
|
111
|
+
```
|
|
112
|
+
|
|
113
|
+
## Documentation
|
|
114
|
+
|
|
115
|
+
For full documentation, visit [https://withmetis.github.io/mantis/mantisdk/](https://withmetis.github.io/mantis/mantisdk/)
|
|
116
|
+
|
|
117
|
+
## License
|
|
118
|
+
|
|
119
|
+
MIT License - see [LICENSE](LICENSE) for details.
|