mantisdk 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mantisdk might be problematic; consult the registry's advisory page for details.
- mantisdk/__init__.py +22 -0
- mantisdk/adapter/__init__.py +15 -0
- mantisdk/adapter/base.py +94 -0
- mantisdk/adapter/messages.py +270 -0
- mantisdk/adapter/triplet.py +1028 -0
- mantisdk/algorithm/__init__.py +39 -0
- mantisdk/algorithm/apo/__init__.py +5 -0
- mantisdk/algorithm/apo/apo.py +889 -0
- mantisdk/algorithm/apo/prompts/apply_edit_variant01.poml +22 -0
- mantisdk/algorithm/apo/prompts/apply_edit_variant02.poml +18 -0
- mantisdk/algorithm/apo/prompts/text_gradient_variant01.poml +18 -0
- mantisdk/algorithm/apo/prompts/text_gradient_variant02.poml +16 -0
- mantisdk/algorithm/apo/prompts/text_gradient_variant03.poml +107 -0
- mantisdk/algorithm/base.py +162 -0
- mantisdk/algorithm/decorator.py +264 -0
- mantisdk/algorithm/fast.py +250 -0
- mantisdk/algorithm/gepa/__init__.py +59 -0
- mantisdk/algorithm/gepa/adapter.py +459 -0
- mantisdk/algorithm/gepa/gepa.py +364 -0
- mantisdk/algorithm/gepa/lib/__init__.py +18 -0
- mantisdk/algorithm/gepa/lib/adapters/README.md +12 -0
- mantisdk/algorithm/gepa/lib/adapters/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/README.md +341 -0
- mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/__init__.py +1 -0
- mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/anymaths_adapter.py +174 -0
- mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/requirements.txt +1 -0
- mantisdk/algorithm/gepa/lib/adapters/default_adapter/README.md +0 -0
- mantisdk/algorithm/gepa/lib/adapters/default_adapter/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/adapters/default_adapter/default_adapter.py +209 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/README.md +7 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/dspy_adapter.py +307 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/README.md +99 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/dspy_program_proposal_signature.py +137 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/full_program_adapter.py +266 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/GEPA_RAG.md +621 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/__init__.py +56 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/evaluation_metrics.py +226 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/generic_rag_adapter.py +496 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/rag_pipeline.py +238 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_store_interface.py +212 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/__init__.py +2 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/chroma_store.py +196 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/lancedb_store.py +422 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/milvus_store.py +409 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/qdrant_store.py +368 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/weaviate_store.py +418 -0
- mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/README.md +552 -0
- mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/__init__.py +37 -0
- mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/mcp_adapter.py +705 -0
- mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/mcp_client.py +364 -0
- mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/README.md +9 -0
- mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/terminal_bench_adapter.py +217 -0
- mantisdk/algorithm/gepa/lib/api.py +375 -0
- mantisdk/algorithm/gepa/lib/core/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/core/adapter.py +180 -0
- mantisdk/algorithm/gepa/lib/core/data_loader.py +74 -0
- mantisdk/algorithm/gepa/lib/core/engine.py +356 -0
- mantisdk/algorithm/gepa/lib/core/result.py +233 -0
- mantisdk/algorithm/gepa/lib/core/state.py +636 -0
- mantisdk/algorithm/gepa/lib/examples/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/examples/aime.py +24 -0
- mantisdk/algorithm/gepa/lib/examples/anymaths-bench/eval_default.py +111 -0
- mantisdk/algorithm/gepa/lib/examples/anymaths-bench/prompt-templates/instruction_prompt.txt +9 -0
- mantisdk/algorithm/gepa/lib/examples/anymaths-bench/prompt-templates/optimal_prompt.txt +24 -0
- mantisdk/algorithm/gepa/lib/examples/anymaths-bench/train_anymaths.py +177 -0
- mantisdk/algorithm/gepa/lib/examples/dspy_full_program_evolution/arc_agi.ipynb +25705 -0
- mantisdk/algorithm/gepa/lib/examples/dspy_full_program_evolution/example.ipynb +348 -0
- mantisdk/algorithm/gepa/lib/examples/mcp_adapter/__init__.py +4 -0
- mantisdk/algorithm/gepa/lib/examples/mcp_adapter/mcp_optimization_example.py +455 -0
- mantisdk/algorithm/gepa/lib/examples/rag_adapter/RAG_GUIDE.md +613 -0
- mantisdk/algorithm/gepa/lib/examples/rag_adapter/__init__.py +9 -0
- mantisdk/algorithm/gepa/lib/examples/rag_adapter/rag_optimization.py +824 -0
- mantisdk/algorithm/gepa/lib/examples/rag_adapter/requirements-rag.txt +29 -0
- mantisdk/algorithm/gepa/lib/examples/terminal-bench/prompt-templates/instruction_prompt.txt +16 -0
- mantisdk/algorithm/gepa/lib/examples/terminal-bench/prompt-templates/terminus.txt +9 -0
- mantisdk/algorithm/gepa/lib/examples/terminal-bench/train_terminus.py +161 -0
- mantisdk/algorithm/gepa/lib/gepa_utils.py +117 -0
- mantisdk/algorithm/gepa/lib/logging/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/logging/experiment_tracker.py +187 -0
- mantisdk/algorithm/gepa/lib/logging/logger.py +75 -0
- mantisdk/algorithm/gepa/lib/logging/utils.py +103 -0
- mantisdk/algorithm/gepa/lib/proposer/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/proposer/base.py +31 -0
- mantisdk/algorithm/gepa/lib/proposer/merge.py +357 -0
- mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/base.py +49 -0
- mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/reflective_mutation.py +176 -0
- mantisdk/algorithm/gepa/lib/py.typed +0 -0
- mantisdk/algorithm/gepa/lib/strategies/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/strategies/batch_sampler.py +77 -0
- mantisdk/algorithm/gepa/lib/strategies/candidate_selector.py +50 -0
- mantisdk/algorithm/gepa/lib/strategies/component_selector.py +36 -0
- mantisdk/algorithm/gepa/lib/strategies/eval_policy.py +64 -0
- mantisdk/algorithm/gepa/lib/strategies/instruction_proposal.py +127 -0
- mantisdk/algorithm/gepa/lib/utils/__init__.py +10 -0
- mantisdk/algorithm/gepa/lib/utils/stop_condition.py +196 -0
- mantisdk/algorithm/gepa/tracing.py +105 -0
- mantisdk/algorithm/utils.py +177 -0
- mantisdk/algorithm/verl/__init__.py +5 -0
- mantisdk/algorithm/verl/interface.py +202 -0
- mantisdk/cli/__init__.py +56 -0
- mantisdk/cli/prometheus.py +115 -0
- mantisdk/cli/store.py +131 -0
- mantisdk/cli/vllm.py +29 -0
- mantisdk/client.py +408 -0
- mantisdk/config.py +348 -0
- mantisdk/emitter/__init__.py +43 -0
- mantisdk/emitter/annotation.py +370 -0
- mantisdk/emitter/exception.py +54 -0
- mantisdk/emitter/message.py +61 -0
- mantisdk/emitter/object.py +117 -0
- mantisdk/emitter/reward.py +320 -0
- mantisdk/env_var.py +156 -0
- mantisdk/execution/__init__.py +15 -0
- mantisdk/execution/base.py +64 -0
- mantisdk/execution/client_server.py +443 -0
- mantisdk/execution/events.py +69 -0
- mantisdk/execution/inter_process.py +16 -0
- mantisdk/execution/shared_memory.py +282 -0
- mantisdk/instrumentation/__init__.py +119 -0
- mantisdk/instrumentation/agentops.py +314 -0
- mantisdk/instrumentation/agentops_langchain.py +45 -0
- mantisdk/instrumentation/litellm.py +83 -0
- mantisdk/instrumentation/vllm.py +81 -0
- mantisdk/instrumentation/weave.py +500 -0
- mantisdk/litagent/__init__.py +11 -0
- mantisdk/litagent/decorator.py +536 -0
- mantisdk/litagent/litagent.py +252 -0
- mantisdk/llm_proxy.py +1890 -0
- mantisdk/logging.py +370 -0
- mantisdk/reward.py +7 -0
- mantisdk/runner/__init__.py +11 -0
- mantisdk/runner/agent.py +845 -0
- mantisdk/runner/base.py +182 -0
- mantisdk/runner/legacy.py +309 -0
- mantisdk/semconv.py +170 -0
- mantisdk/server.py +401 -0
- mantisdk/store/__init__.py +23 -0
- mantisdk/store/base.py +897 -0
- mantisdk/store/client_server.py +2092 -0
- mantisdk/store/collection/__init__.py +30 -0
- mantisdk/store/collection/base.py +587 -0
- mantisdk/store/collection/memory.py +970 -0
- mantisdk/store/collection/mongo.py +1412 -0
- mantisdk/store/collection_based.py +1823 -0
- mantisdk/store/insight.py +648 -0
- mantisdk/store/listener.py +58 -0
- mantisdk/store/memory.py +396 -0
- mantisdk/store/mongo.py +165 -0
- mantisdk/store/sqlite.py +3 -0
- mantisdk/store/threading.py +357 -0
- mantisdk/store/utils.py +142 -0
- mantisdk/tracer/__init__.py +16 -0
- mantisdk/tracer/agentops.py +242 -0
- mantisdk/tracer/base.py +287 -0
- mantisdk/tracer/dummy.py +106 -0
- mantisdk/tracer/otel.py +555 -0
- mantisdk/tracer/weave.py +677 -0
- mantisdk/trainer/__init__.py +6 -0
- mantisdk/trainer/init_utils.py +263 -0
- mantisdk/trainer/legacy.py +367 -0
- mantisdk/trainer/registry.py +12 -0
- mantisdk/trainer/trainer.py +618 -0
- mantisdk/types/__init__.py +6 -0
- mantisdk/types/core.py +553 -0
- mantisdk/types/resources.py +204 -0
- mantisdk/types/tracer.py +515 -0
- mantisdk/types/tracing.py +218 -0
- mantisdk/utils/__init__.py +1 -0
- mantisdk/utils/id.py +18 -0
- mantisdk/utils/metrics.py +1025 -0
- mantisdk/utils/otel.py +578 -0
- mantisdk/utils/otlp.py +536 -0
- mantisdk/utils/server_launcher.py +1045 -0
- mantisdk/utils/system_snapshot.py +81 -0
- mantisdk/verl/__init__.py +8 -0
- mantisdk/verl/__main__.py +6 -0
- mantisdk/verl/async_server.py +46 -0
- mantisdk/verl/config.yaml +27 -0
- mantisdk/verl/daemon.py +1154 -0
- mantisdk/verl/dataset.py +44 -0
- mantisdk/verl/entrypoint.py +248 -0
- mantisdk/verl/trainer.py +549 -0
- mantisdk-0.1.0.dist-info/METADATA +119 -0
- mantisdk-0.1.0.dist-info/RECORD +190 -0
- mantisdk-0.1.0.dist-info/WHEEL +4 -0
- mantisdk-0.1.0.dist-info/entry_points.txt +2 -0
- mantisdk-0.1.0.dist-info/licenses/LICENSE +19 -0
mantisdk/verl/dataset.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
# Copyright (c) Microsoft. All rights reserved.
|
|
2
|
+
|
|
3
|
+
# type: ignore
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
from datasets import Dataset as HuggingFaceDataset
|
|
7
|
+
from omegaconf import DictConfig
|
|
8
|
+
from verl.utils.dataset.rl_dataset import RLHFDataset
|
|
9
|
+
|
|
10
|
+
from mantisdk.types import Dataset
|
|
11
|
+
|
|
12
|
+
__all__ = [
|
|
13
|
+
"AgentDataset",
|
|
14
|
+
"LoadedDataset",
|
|
15
|
+
]
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class AgentDataset(RLHFDataset):
    """RLHF dataset variant for agent-mode training.

    Differences from the stock ``RLHFDataset``:

    * overlong-prompt filtering is force-disabled (rows here are task
      specifications consumed by an agent, not chat prompts to tokenize), and
    * ``__getitem__`` returns the raw row dict augmented with an ``index``
      and a one-element placeholder tensor instead of tokenized fields.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Set after super().__init__ so it overrides whatever the parent
        # derived from the config.
        self.filter_overlong_prompts = False

    def __getitem__(self, item):
        """Return the raw row with an ``index`` and a placeholder tensor.

        Args:
            item: Integer position into ``self.dataframe``.

        Returns:
            The row dict, mutated in place with ``index`` and ``fake_ids``.
        """
        row_dict: dict = self.dataframe[item]

        # Add index for each prompt. ``or {}`` also covers rows where the
        # "extra_info" column exists but holds None (null in parquet/HF
        # exports), which would otherwise raise AttributeError on .get().
        extra_info = row_dict.get("extra_info") or {}
        row_dict["index"] = extra_info.get("index", 0)
        # Workaround for data proto. At least one tensor is needed.
        row_dict["fake_ids"] = torch.ones(1, dtype=torch.int)
        return row_dict
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class LoadedDataset(AgentDataset):
    """Wrap an in-memory mantisdk ``Dataset`` as an ``AgentDataset``.

    The parent's file-reading path is bypassed entirely: every row of the
    given dataset is copied into a HuggingFace dataset stored on
    ``self.dataframe``, and ``_read_files_and_tokenize`` is a no-op.
    """

    def __init__(self, dataset: Dataset):
        # The parent signature wants (data_files, tokenizer, config); the
        # values are irrelevant here because no files are ever read.
        super().__init__([], None, DictConfig({}))  # type: ignore
        rows = [dataset[position] for position in range(len(dataset))]
        self.dataframe = HuggingFaceDataset.from_list(rows)

    def _read_files_and_tokenize(self):
        # Intentionally empty: the data was injected directly in __init__.
        pass
|
|
@@ -0,0 +1,248 @@
|
|
|
1
|
+
# Copyright (c) Microsoft. All rights reserved.
|
|
2
|
+
|
|
3
|
+
# pyright: reportUnknownVariableType=false
|
|
4
|
+
# pyright: reportUnknownMemberType=false
|
|
5
|
+
# pyright: reportUnknownArgumentType=false
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
from typing import TYPE_CHECKING, Any, Type
|
|
10
|
+
|
|
11
|
+
import hydra
|
|
12
|
+
import ray
|
|
13
|
+
from ray.actor import ActorClass
|
|
14
|
+
from verl.trainer.main_ppo import create_rl_sampler
|
|
15
|
+
from verl.trainer.ppo.reward import load_reward_manager
|
|
16
|
+
|
|
17
|
+
from mantisdk.adapter import TraceAdapter
|
|
18
|
+
from mantisdk.llm_proxy import LLMProxy
|
|
19
|
+
from mantisdk.store.base import LightningStore
|
|
20
|
+
from mantisdk.types import Dataset
|
|
21
|
+
|
|
22
|
+
from .dataset import AgentDataset, LoadedDataset
|
|
23
|
+
|
|
24
|
+
if TYPE_CHECKING:
|
|
25
|
+
from .daemon import AgentModeDaemon
|
|
26
|
+
from .trainer import MantisdkTrainer
|
|
27
|
+
|
|
28
|
+
__all__ = [
|
|
29
|
+
"main",
|
|
30
|
+
"run_ppo",
|
|
31
|
+
"TaskRunner",
|
|
32
|
+
]
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
@hydra.main(config_path="pkg://mantisdk/verl", config_name="config", version_base=None)
def main(config: Any):
    """Hydra CLI entry point: run PPO with the default trainer and daemon.

    All optional collaborators (datasets, store, proxy, adapter) are left
    unset so ``run_ppo`` builds them from ``config`` alone.
    """
    from .daemon import AgentModeDaemon
    from .trainer import MantisdkTrainer

    unset: dict[str, Any] = dict(
        train_dataset=None,
        val_dataset=None,
        store=None,
        llm_proxy=None,
        adapter=None,
    )
    run_ppo(config, trainer_cls=MantisdkTrainer, daemon_cls=AgentModeDaemon, **unset)
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def run_ppo(
    config: Any,
    train_dataset: Dataset[Any] | None,
    val_dataset: Dataset[Any] | None,
    store: LightningStore | None,
    llm_proxy: LLMProxy | None,
    adapter: TraceAdapter[Any] | None,
    trainer_cls: Type[MantisdkTrainer],
    daemon_cls: Type[AgentModeDaemon],
) -> None:
    """Bootstrap Ray if needed, then run training on a remote ``TaskRunner``.

    Blocks until the remote ``run`` call finishes.
    """
    if not ray.is_initialized():
        # Local Ray cluster. The num_cpus knob moved under `ray_kwargs`
        # in verl 0.6.0, so probe the new location first.
        try:
            cpu_budget = config.ray_kwargs.ray_init.num_cpus  # verl >= 0.6.0
        except AttributeError:
            cpu_budget = config.ray_init.num_cpus  # verl < 0.6.0
        worker_env = {
            "env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN", "VLLM_LOGGING_LEVEL": "WARN"}
        }
        ray.init(runtime_env=worker_env, num_cpus=cpu_budget)

    runner = TaskRunner.remote()
    task_ref = runner.run.remote(  # type: ignore
        config=config,
        train_dataset=train_dataset,
        val_dataset=val_dataset,
        store=store,
        llm_proxy=llm_proxy,
        adapter=adapter,
        trainer_cls=trainer_cls,
        daemon_cls=daemon_cls,
    )
    ray.get(task_ref)
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
@ray.remote(num_cpus=1)  # please make sure main_task is not scheduled on head
class TaskRunner:
    """Ray actor that assembles and drives one full PPO training run."""

    def run(
        self,
        config: Any,
        train_dataset: Dataset[Any] | None,
        val_dataset: Dataset[Any] | None,
        store: LightningStore | None,
        llm_proxy: LLMProxy | None,
        adapter: TraceAdapter[Any] | None,
        trainer_cls: Type[MantisdkTrainer],
        daemon_cls: Type[AgentModeDaemon],
    ):
        """Build workers, rewards, and datasets from ``config``, then train.

        Args:
            config: Resolved Hydra/OmegaConf training config; mutated in
                place by ``OmegaConf.resolve``.
            train_dataset / val_dataset: Optional preloaded datasets; when
                None, they are read from ``config.data.*_files``.
            store / llm_proxy / adapter: Optional collaborators forwarded
                to the trainer as-is.
            trainer_cls / daemon_cls: Concrete classes to instantiate.

        Note: verl imports are done lazily inside this method so they happen
        on the actor's worker process, not at module import time.
        """
        # print initial config
        from pprint import pprint

        from omegaconf import OmegaConf
        from verl.utils.fs import copy_to_local

        pprint(OmegaConf.to_container(config, resolve=True))  # resolve=True will eval symbol values
        OmegaConf.resolve(config)

        # download the checkpoint from hdfs
        local_path = copy_to_local(config.actor_rollout_ref.model.path)

        # instantiate tokenizer
        from verl.utils.tokenizer import hf_processor, hf_tokenizer

        trust_remote_code = config.data.get("trust_remote_code", False)
        tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)
        processor = hf_processor(local_path, use_fast=True)  # used for multimodal LLM, could be none

        # define worker classes (FSDP/FSDP2 vs Megatron backend; the config
        # must use the same strategy family for actor and critic)
        if config.actor_rollout_ref.actor.strategy in ["fsdp", "fsdp2"]:
            assert config.critic.strategy in ["fsdp", "fsdp2"]
            from verl.single_controller.ray import RayWorkerGroup
            from verl.workers.fsdp_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker, CriticWorker

            actor_rollout_cls = (
                AsyncActorRolloutRefWorker
                if config.actor_rollout_ref.rollout.mode == "async"
                else ActorRolloutRefWorker
            )
            ray_worker_group_cls = RayWorkerGroup

        elif config.actor_rollout_ref.actor.strategy == "megatron":
            assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
            # FIXME: This import is outdated
            from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup  # type: ignore
            from verl.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker

            actor_rollout_cls = ActorRolloutRefWorker
            ray_worker_group_cls = NVMegatronRayWorkerGroup

        else:
            raise NotImplementedError

        from verl.trainer.ppo.ray_trainer import ResourcePoolManager

        try:
            # verl >= 0.6.0
            from verl.trainer.ppo.utils import Role
        except ImportError:
            # Fallback for verl <= 0.5.0
            from verl.trainer.ppo.ray_trainer import Role  # type: ignore

        # Role -> remote worker class; extended below with reward/ref roles.
        role_worker_mapping: dict[Role, ActorClass[Any]] = {
            Role.ActorRollout: ray.remote(actor_rollout_cls),
            Role.Critic: ray.remote(CriticWorker),
        }

        # Single shared GPU pool: every role is scheduled on the same
        # n_gpus_per_node * nnodes slice.
        global_pool_id = "global_pool"
        resource_pool_spec = {
            global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes,
        }
        mapping = {
            Role.ActorRollout: global_pool_id,
            Role.Critic: global_pool_id,
        }

        # we should adopt a multi-source reward function here
        # - for rule-based rm, we directly call a reward score
        # - for model-based rm, we call a model
        # - for code related prompt, we send to a sandbox if there are test cases
        # - finally, we combine all the rewards together
        # - The reward type depends on the tag of the data
        if config.reward_model.enable:
            if config.reward_model.strategy in ["fsdp", "fsdp2"]:
                from verl.workers.fsdp_workers import RewardModelWorker
            elif config.reward_model.strategy == "megatron":
                from verl.workers.megatron_workers import RewardModelWorker
            else:
                raise NotImplementedError
            role_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker)
            mapping[Role.RewardModel] = global_pool_id

        # use reference model (needed whenever a KL term is in play)
        if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss:
            role_worker_mapping[Role.RefPolicy] = ray.remote(ActorRolloutRefWorker)
            mapping[Role.RefPolicy] = global_pool_id

        # num_examine controls how many decoded samples the manager prints
        # for inspection: none during training, one during validation.
        reward_fn = load_reward_manager(
            config, tokenizer, num_examine=0, **config.reward_model.get("reward_kwargs", {})
        )
        val_reward_fn = load_reward_manager(
            config, tokenizer, num_examine=1, **config.reward_model.get("reward_kwargs", {})
        )
        resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)

        from verl.utils.dataset.rl_dataset import collate_fn

        # Use our special dataset
        if train_dataset is None:
            train_dataset = AgentDataset(
                data_files=config.data.train_files,
                tokenizer=tokenizer,
                processor=processor,
                config=config.data,
            )
        else:
            train_dataset = LoadedDataset(train_dataset)

        if val_dataset is None:
            val_dataset = AgentDataset(
                data_files=config.data.val_files,
                tokenizer=tokenizer,
                processor=processor,
                config=config.data,
            )
        else:
            val_dataset = LoadedDataset(val_dataset)

        train_sampler = create_rl_sampler(config.data, train_dataset)
        trainer = trainer_cls(
            config=config,
            tokenizer=tokenizer,
            processor=processor,
            role_worker_mapping=role_worker_mapping,
            resource_pool_manager=resource_pool_manager,
            ray_worker_group_cls=ray_worker_group_cls,
            reward_fn=reward_fn,
            val_reward_fn=val_reward_fn,
            train_dataset=train_dataset,
            val_dataset=val_dataset,
            collate_fn=collate_fn,
            train_sampler=train_sampler,
            store=store,
            llm_proxy=llm_proxy,
            adapter=adapter,
            daemon_cls=daemon_cls,
        )
        trainer.init_workers()
        trainer.fit()
|
|
245
|
+
|
|
246
|
+
|
|
247
|
+
if __name__ == "__main__":
    # Script entry point; the @hydra.main decorator on `main` parses the CLI.
    main()
|