mantisdk-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of mantisdk might be problematic.

Files changed (190)
  1. mantisdk/__init__.py +22 -0
  2. mantisdk/adapter/__init__.py +15 -0
  3. mantisdk/adapter/base.py +94 -0
  4. mantisdk/adapter/messages.py +270 -0
  5. mantisdk/adapter/triplet.py +1028 -0
  6. mantisdk/algorithm/__init__.py +39 -0
  7. mantisdk/algorithm/apo/__init__.py +5 -0
  8. mantisdk/algorithm/apo/apo.py +889 -0
  9. mantisdk/algorithm/apo/prompts/apply_edit_variant01.poml +22 -0
  10. mantisdk/algorithm/apo/prompts/apply_edit_variant02.poml +18 -0
  11. mantisdk/algorithm/apo/prompts/text_gradient_variant01.poml +18 -0
  12. mantisdk/algorithm/apo/prompts/text_gradient_variant02.poml +16 -0
  13. mantisdk/algorithm/apo/prompts/text_gradient_variant03.poml +107 -0
  14. mantisdk/algorithm/base.py +162 -0
  15. mantisdk/algorithm/decorator.py +264 -0
  16. mantisdk/algorithm/fast.py +250 -0
  17. mantisdk/algorithm/gepa/__init__.py +59 -0
  18. mantisdk/algorithm/gepa/adapter.py +459 -0
  19. mantisdk/algorithm/gepa/gepa.py +364 -0
  20. mantisdk/algorithm/gepa/lib/__init__.py +18 -0
  21. mantisdk/algorithm/gepa/lib/adapters/README.md +12 -0
  22. mantisdk/algorithm/gepa/lib/adapters/__init__.py +0 -0
  23. mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/README.md +341 -0
  24. mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/__init__.py +1 -0
  25. mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/anymaths_adapter.py +174 -0
  26. mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/requirements.txt +1 -0
  27. mantisdk/algorithm/gepa/lib/adapters/default_adapter/README.md +0 -0
  28. mantisdk/algorithm/gepa/lib/adapters/default_adapter/__init__.py +0 -0
  29. mantisdk/algorithm/gepa/lib/adapters/default_adapter/default_adapter.py +209 -0
  30. mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/README.md +7 -0
  31. mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/__init__.py +0 -0
  32. mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/dspy_adapter.py +307 -0
  33. mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/README.md +99 -0
  34. mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/dspy_program_proposal_signature.py +137 -0
  35. mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/full_program_adapter.py +266 -0
  36. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/GEPA_RAG.md +621 -0
  37. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/__init__.py +56 -0
  38. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/evaluation_metrics.py +226 -0
  39. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/generic_rag_adapter.py +496 -0
  40. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/rag_pipeline.py +238 -0
  41. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_store_interface.py +212 -0
  42. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/__init__.py +2 -0
  43. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/chroma_store.py +196 -0
  44. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/lancedb_store.py +422 -0
  45. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/milvus_store.py +409 -0
  46. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/qdrant_store.py +368 -0
  47. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/weaviate_store.py +418 -0
  48. mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/README.md +552 -0
  49. mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/__init__.py +37 -0
  50. mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/mcp_adapter.py +705 -0
  51. mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/mcp_client.py +364 -0
  52. mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/README.md +9 -0
  53. mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/__init__.py +0 -0
  54. mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/terminal_bench_adapter.py +217 -0
  55. mantisdk/algorithm/gepa/lib/api.py +375 -0
  56. mantisdk/algorithm/gepa/lib/core/__init__.py +0 -0
  57. mantisdk/algorithm/gepa/lib/core/adapter.py +180 -0
  58. mantisdk/algorithm/gepa/lib/core/data_loader.py +74 -0
  59. mantisdk/algorithm/gepa/lib/core/engine.py +356 -0
  60. mantisdk/algorithm/gepa/lib/core/result.py +233 -0
  61. mantisdk/algorithm/gepa/lib/core/state.py +636 -0
  62. mantisdk/algorithm/gepa/lib/examples/__init__.py +0 -0
  63. mantisdk/algorithm/gepa/lib/examples/aime.py +24 -0
  64. mantisdk/algorithm/gepa/lib/examples/anymaths-bench/eval_default.py +111 -0
  65. mantisdk/algorithm/gepa/lib/examples/anymaths-bench/prompt-templates/instruction_prompt.txt +9 -0
  66. mantisdk/algorithm/gepa/lib/examples/anymaths-bench/prompt-templates/optimal_prompt.txt +24 -0
  67. mantisdk/algorithm/gepa/lib/examples/anymaths-bench/train_anymaths.py +177 -0
  68. mantisdk/algorithm/gepa/lib/examples/dspy_full_program_evolution/arc_agi.ipynb +25705 -0
  69. mantisdk/algorithm/gepa/lib/examples/dspy_full_program_evolution/example.ipynb +348 -0
  70. mantisdk/algorithm/gepa/lib/examples/mcp_adapter/__init__.py +4 -0
  71. mantisdk/algorithm/gepa/lib/examples/mcp_adapter/mcp_optimization_example.py +455 -0
  72. mantisdk/algorithm/gepa/lib/examples/rag_adapter/RAG_GUIDE.md +613 -0
  73. mantisdk/algorithm/gepa/lib/examples/rag_adapter/__init__.py +9 -0
  74. mantisdk/algorithm/gepa/lib/examples/rag_adapter/rag_optimization.py +824 -0
  75. mantisdk/algorithm/gepa/lib/examples/rag_adapter/requirements-rag.txt +29 -0
  76. mantisdk/algorithm/gepa/lib/examples/terminal-bench/prompt-templates/instruction_prompt.txt +16 -0
  77. mantisdk/algorithm/gepa/lib/examples/terminal-bench/prompt-templates/terminus.txt +9 -0
  78. mantisdk/algorithm/gepa/lib/examples/terminal-bench/train_terminus.py +161 -0
  79. mantisdk/algorithm/gepa/lib/gepa_utils.py +117 -0
  80. mantisdk/algorithm/gepa/lib/logging/__init__.py +0 -0
  81. mantisdk/algorithm/gepa/lib/logging/experiment_tracker.py +187 -0
  82. mantisdk/algorithm/gepa/lib/logging/logger.py +75 -0
  83. mantisdk/algorithm/gepa/lib/logging/utils.py +103 -0
  84. mantisdk/algorithm/gepa/lib/proposer/__init__.py +0 -0
  85. mantisdk/algorithm/gepa/lib/proposer/base.py +31 -0
  86. mantisdk/algorithm/gepa/lib/proposer/merge.py +357 -0
  87. mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/__init__.py +0 -0
  88. mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/base.py +49 -0
  89. mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/reflective_mutation.py +176 -0
  90. mantisdk/algorithm/gepa/lib/py.typed +0 -0
  91. mantisdk/algorithm/gepa/lib/strategies/__init__.py +0 -0
  92. mantisdk/algorithm/gepa/lib/strategies/batch_sampler.py +77 -0
  93. mantisdk/algorithm/gepa/lib/strategies/candidate_selector.py +50 -0
  94. mantisdk/algorithm/gepa/lib/strategies/component_selector.py +36 -0
  95. mantisdk/algorithm/gepa/lib/strategies/eval_policy.py +64 -0
  96. mantisdk/algorithm/gepa/lib/strategies/instruction_proposal.py +127 -0
  97. mantisdk/algorithm/gepa/lib/utils/__init__.py +10 -0
  98. mantisdk/algorithm/gepa/lib/utils/stop_condition.py +196 -0
  99. mantisdk/algorithm/gepa/tracing.py +105 -0
  100. mantisdk/algorithm/utils.py +177 -0
  101. mantisdk/algorithm/verl/__init__.py +5 -0
  102. mantisdk/algorithm/verl/interface.py +202 -0
  103. mantisdk/cli/__init__.py +56 -0
  104. mantisdk/cli/prometheus.py +115 -0
  105. mantisdk/cli/store.py +131 -0
  106. mantisdk/cli/vllm.py +29 -0
  107. mantisdk/client.py +408 -0
  108. mantisdk/config.py +348 -0
  109. mantisdk/emitter/__init__.py +43 -0
  110. mantisdk/emitter/annotation.py +370 -0
  111. mantisdk/emitter/exception.py +54 -0
  112. mantisdk/emitter/message.py +61 -0
  113. mantisdk/emitter/object.py +117 -0
  114. mantisdk/emitter/reward.py +320 -0
  115. mantisdk/env_var.py +156 -0
  116. mantisdk/execution/__init__.py +15 -0
  117. mantisdk/execution/base.py +64 -0
  118. mantisdk/execution/client_server.py +443 -0
  119. mantisdk/execution/events.py +69 -0
  120. mantisdk/execution/inter_process.py +16 -0
  121. mantisdk/execution/shared_memory.py +282 -0
  122. mantisdk/instrumentation/__init__.py +119 -0
  123. mantisdk/instrumentation/agentops.py +314 -0
  124. mantisdk/instrumentation/agentops_langchain.py +45 -0
  125. mantisdk/instrumentation/litellm.py +83 -0
  126. mantisdk/instrumentation/vllm.py +81 -0
  127. mantisdk/instrumentation/weave.py +500 -0
  128. mantisdk/litagent/__init__.py +11 -0
  129. mantisdk/litagent/decorator.py +536 -0
  130. mantisdk/litagent/litagent.py +252 -0
  131. mantisdk/llm_proxy.py +1890 -0
  132. mantisdk/logging.py +370 -0
  133. mantisdk/reward.py +7 -0
  134. mantisdk/runner/__init__.py +11 -0
  135. mantisdk/runner/agent.py +845 -0
  136. mantisdk/runner/base.py +182 -0
  137. mantisdk/runner/legacy.py +309 -0
  138. mantisdk/semconv.py +170 -0
  139. mantisdk/server.py +401 -0
  140. mantisdk/store/__init__.py +23 -0
  141. mantisdk/store/base.py +897 -0
  142. mantisdk/store/client_server.py +2092 -0
  143. mantisdk/store/collection/__init__.py +30 -0
  144. mantisdk/store/collection/base.py +587 -0
  145. mantisdk/store/collection/memory.py +970 -0
  146. mantisdk/store/collection/mongo.py +1412 -0
  147. mantisdk/store/collection_based.py +1823 -0
  148. mantisdk/store/insight.py +648 -0
  149. mantisdk/store/listener.py +58 -0
  150. mantisdk/store/memory.py +396 -0
  151. mantisdk/store/mongo.py +165 -0
  152. mantisdk/store/sqlite.py +3 -0
  153. mantisdk/store/threading.py +357 -0
  154. mantisdk/store/utils.py +142 -0
  155. mantisdk/tracer/__init__.py +16 -0
  156. mantisdk/tracer/agentops.py +242 -0
  157. mantisdk/tracer/base.py +287 -0
  158. mantisdk/tracer/dummy.py +106 -0
  159. mantisdk/tracer/otel.py +555 -0
  160. mantisdk/tracer/weave.py +677 -0
  161. mantisdk/trainer/__init__.py +6 -0
  162. mantisdk/trainer/init_utils.py +263 -0
  163. mantisdk/trainer/legacy.py +367 -0
  164. mantisdk/trainer/registry.py +12 -0
  165. mantisdk/trainer/trainer.py +618 -0
  166. mantisdk/types/__init__.py +6 -0
  167. mantisdk/types/core.py +553 -0
  168. mantisdk/types/resources.py +204 -0
  169. mantisdk/types/tracer.py +515 -0
  170. mantisdk/types/tracing.py +218 -0
  171. mantisdk/utils/__init__.py +1 -0
  172. mantisdk/utils/id.py +18 -0
  173. mantisdk/utils/metrics.py +1025 -0
  174. mantisdk/utils/otel.py +578 -0
  175. mantisdk/utils/otlp.py +536 -0
  176. mantisdk/utils/server_launcher.py +1045 -0
  177. mantisdk/utils/system_snapshot.py +81 -0
  178. mantisdk/verl/__init__.py +8 -0
  179. mantisdk/verl/__main__.py +6 -0
  180. mantisdk/verl/async_server.py +46 -0
  181. mantisdk/verl/config.yaml +27 -0
  182. mantisdk/verl/daemon.py +1154 -0
  183. mantisdk/verl/dataset.py +44 -0
  184. mantisdk/verl/entrypoint.py +248 -0
  185. mantisdk/verl/trainer.py +549 -0
  186. mantisdk-0.1.0.dist-info/METADATA +119 -0
  187. mantisdk-0.1.0.dist-info/RECORD +190 -0
  188. mantisdk-0.1.0.dist-info/WHEEL +4 -0
  189. mantisdk-0.1.0.dist-info/entry_points.txt +2 -0
  190. mantisdk-0.1.0.dist-info/licenses/LICENSE +19 -0
mantisdk/verl/dataset.py
@@ -0,0 +1,44 @@
+ # Copyright (c) Microsoft. All rights reserved.
+
+ # type: ignore
+
+ import torch
+ from datasets import Dataset as HuggingFaceDataset
+ from omegaconf import DictConfig
+ from verl.utils.dataset.rl_dataset import RLHFDataset
+
+ from mantisdk.types import Dataset
+
+ __all__ = [
+     "AgentDataset",
+     "LoadedDataset",
+ ]
+
+
+ class AgentDataset(RLHFDataset):
+
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+
+         self.filter_overlong_prompts = False
+
+     def __getitem__(self, item):
+         row_dict: dict = self.dataframe[item]
+
+         # add index for each prompt
+         index = row_dict.get("extra_info", {}).get("index", 0)
+         row_dict["index"] = index
+         # Workaround for data proto. At least one tensor is needed.
+         row_dict["fake_ids"] = torch.ones(1, dtype=torch.int)
+         return row_dict
+
+
+ class LoadedDataset(AgentDataset):
+
+     def __init__(self, dataset: Dataset):
+         super().__init__([], None, DictConfig({}))  # type: ignore
+         dataset_copy = [dataset[i] for i in range(len(dataset))]
+         self.dataframe = HuggingFaceDataset.from_list(dataset_copy)
+
+     def _read_files_and_tokenize(self):
+         pass
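
The two classes above adapt mantisdk data to verl's RLHFDataset interface: AgentDataset disables prompt-length filtering and injects an "index" plus a placeholder "fake_ids" tensor, while LoadedDataset wraps an in-memory mantisdk Dataset instead of reading data files. A hypothetical usage sketch (not part of the package; assumes torch, datasets, omegaconf, and verl are installed, and that any index-able sequence of dict rows satisfies the mantisdk Dataset protocol):

from mantisdk.verl.dataset import LoadedDataset

# Placeholder rows standing in for a mantisdk Dataset.
rows = [
    {"prompt": "What is 2 + 2?", "extra_info": {"index": 0}},
    {"prompt": "Name the capital of France.", "extra_info": {"index": 1}},
]

ds = LoadedDataset(rows)   # copies the rows into a HuggingFace Dataset
sample = ds[0]             # __getitem__ adds "index" and a one-element "fake_ids" tensor
print(sample["index"], sample["fake_ids"])
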
mantisdk/verl/entrypoint.py
@@ -0,0 +1,248 @@
+ # Copyright (c) Microsoft. All rights reserved.
+
+ # pyright: reportUnknownVariableType=false
+ # pyright: reportUnknownMemberType=false
+ # pyright: reportUnknownArgumentType=false
+
+ from __future__ import annotations
+
+ from typing import TYPE_CHECKING, Any, Type
+
+ import hydra
+ import ray
+ from ray.actor import ActorClass
+ from verl.trainer.main_ppo import create_rl_sampler
+ from verl.trainer.ppo.reward import load_reward_manager
+
+ from mantisdk.adapter import TraceAdapter
+ from mantisdk.llm_proxy import LLMProxy
+ from mantisdk.store.base import LightningStore
+ from mantisdk.types import Dataset
+
+ from .dataset import AgentDataset, LoadedDataset
+
+ if TYPE_CHECKING:
+     from .daemon import AgentModeDaemon
+     from .trainer import MantisdkTrainer
+
+ __all__ = [
+     "main",
+     "run_ppo",
+     "TaskRunner",
+ ]
+
+
+ @hydra.main(config_path="pkg://mantisdk/verl", config_name="config", version_base=None)
+ def main(config: Any):
+     from .daemon import AgentModeDaemon
+     from .trainer import MantisdkTrainer
+
+     run_ppo(
+         config,
+         train_dataset=None,
+         val_dataset=None,
+         store=None,
+         llm_proxy=None,
+         adapter=None,
+         trainer_cls=MantisdkTrainer,
+         daemon_cls=AgentModeDaemon,
+     )
+
+
+ def run_ppo(
+     config: Any,
+     train_dataset: Dataset[Any] | None,
+     val_dataset: Dataset[Any] | None,
+     store: LightningStore | None,
+     llm_proxy: LLMProxy | None,
+     adapter: TraceAdapter[Any] | None,
+     trainer_cls: Type[MantisdkTrainer],
+     daemon_cls: Type[AgentModeDaemon],
+ ) -> None:
+     if not ray.is_initialized():
+         # this is for local ray cluster
+         try:
+             # verl >= 0.6.0
+             num_cpus = config.ray_kwargs.ray_init.num_cpus
+         except AttributeError:
+             # verl < 0.6.0
+             num_cpus = config.ray_init.num_cpus
+         ray.init(
+             runtime_env={
+                 "env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN", "VLLM_LOGGING_LEVEL": "WARN"}
+             },
+             num_cpus=num_cpus,
+         )
+
+     runner = TaskRunner.remote()
+     ray.get(
+         runner.run.remote(  # type: ignore
+             config=config,
+             train_dataset=train_dataset,
+             val_dataset=val_dataset,
+             store=store,
+             llm_proxy=llm_proxy,
+             adapter=adapter,
+             trainer_cls=trainer_cls,
+             daemon_cls=daemon_cls,
+         )
+     )
+
+
+ @ray.remote(num_cpus=1)  # please make sure main_task is not scheduled on head
+ class TaskRunner:
+     def run(
+         self,
+         config: Any,
+         train_dataset: Dataset[Any] | None,
+         val_dataset: Dataset[Any] | None,
+         store: LightningStore | None,
+         llm_proxy: LLMProxy | None,
+         adapter: TraceAdapter[Any] | None,
+         trainer_cls: Type[MantisdkTrainer],
+         daemon_cls: Type[AgentModeDaemon],
+     ):
+         # print initial config
+         from pprint import pprint
+
+         from omegaconf import OmegaConf
+         from verl.utils.fs import copy_to_local
+
+         pprint(OmegaConf.to_container(config, resolve=True))  # resolve=True will eval symbol values
+         OmegaConf.resolve(config)
+
+         # download the checkpoint from hdfs
+         local_path = copy_to_local(config.actor_rollout_ref.model.path)
+
+         # instantiate tokenizer
+         from verl.utils.tokenizer import hf_processor, hf_tokenizer
+
+         trust_remote_code = config.data.get("trust_remote_code", False)
+         tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)
+         processor = hf_processor(local_path, use_fast=True)  # used for multimodal LLM, could be none
+
+         # define worker classes
+         if config.actor_rollout_ref.actor.strategy in ["fsdp", "fsdp2"]:
+             assert config.critic.strategy in ["fsdp", "fsdp2"]
+             from verl.single_controller.ray import RayWorkerGroup
+             from verl.workers.fsdp_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker, CriticWorker
+
+             actor_rollout_cls = (
+                 AsyncActorRolloutRefWorker
+                 if config.actor_rollout_ref.rollout.mode == "async"
+                 else ActorRolloutRefWorker
+             )
+             ray_worker_group_cls = RayWorkerGroup
+
+         elif config.actor_rollout_ref.actor.strategy == "megatron":
+             assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
+             # FIXME: This import is outdated
+             from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup  # type: ignore
+             from verl.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker
+
+             actor_rollout_cls = ActorRolloutRefWorker
+             ray_worker_group_cls = NVMegatronRayWorkerGroup
+
+         else:
+             raise NotImplementedError
+
+         from verl.trainer.ppo.ray_trainer import ResourcePoolManager
+
+         try:
+             # verl >= 0.6.0
+             from verl.trainer.ppo.utils import Role
+         except ImportError:
+             # Fallback for verl <= 0.5.0
+             from verl.trainer.ppo.ray_trainer import Role  # type: ignore
+
+         role_worker_mapping: dict[Role, ActorClass[Any]] = {
+             Role.ActorRollout: ray.remote(actor_rollout_cls),
+             Role.Critic: ray.remote(CriticWorker),
+         }
+
+         global_pool_id = "global_pool"
+         resource_pool_spec = {
+             global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes,
+         }
+         mapping = {
+             Role.ActorRollout: global_pool_id,
+             Role.Critic: global_pool_id,
+         }
+
+         # we should adopt a multi-source reward function here
+         # - for rule-based rm, we directly call a reward score
+         # - for model-based rm, we call a model
+         # - for code related prompt, we send to a sandbox if there are test cases
+         # - finally, we combine all the rewards together
+         # - The reward type depends on the tag of the data
+         if config.reward_model.enable:
+             if config.reward_model.strategy in ["fsdp", "fsdp2"]:
+                 from verl.workers.fsdp_workers import RewardModelWorker
+             elif config.reward_model.strategy == "megatron":
+                 from verl.workers.megatron_workers import RewardModelWorker
+             else:
+                 raise NotImplementedError
+             role_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker)
+             mapping[Role.RewardModel] = global_pool_id
+
+         # use reference model
+         if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss:
+             role_worker_mapping[Role.RefPolicy] = ray.remote(ActorRolloutRefWorker)
+             mapping[Role.RefPolicy] = global_pool_id
+
+         reward_fn = load_reward_manager(
+             config, tokenizer, num_examine=0, **config.reward_model.get("reward_kwargs", {})
+         )
+         val_reward_fn = load_reward_manager(
+             config, tokenizer, num_examine=1, **config.reward_model.get("reward_kwargs", {})
+         )
+         resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)
+
+         from verl.utils.dataset.rl_dataset import collate_fn
+
+         # Use our special dataset
+         if train_dataset is None:
+             train_dataset = AgentDataset(
+                 data_files=config.data.train_files,
+                 tokenizer=tokenizer,
+                 processor=processor,
+                 config=config.data,
+             )
+         else:
+             train_dataset = LoadedDataset(train_dataset)
+
+         if val_dataset is None:
+             val_dataset = AgentDataset(
+                 data_files=config.data.val_files,
+                 tokenizer=tokenizer,
+                 processor=processor,
+                 config=config.data,
+             )
+         else:
+             val_dataset = LoadedDataset(val_dataset)
+
+         train_sampler = create_rl_sampler(config.data, train_dataset)
+         trainer = trainer_cls(
+             config=config,
+             tokenizer=tokenizer,
+             processor=processor,
+             role_worker_mapping=role_worker_mapping,
+             resource_pool_manager=resource_pool_manager,
+             ray_worker_group_cls=ray_worker_group_cls,
+             reward_fn=reward_fn,
+             val_reward_fn=val_reward_fn,
+             train_dataset=train_dataset,
+             val_dataset=val_dataset,
+             collate_fn=collate_fn,
+             train_sampler=train_sampler,
+             store=store,
+             llm_proxy=llm_proxy,
+             adapter=adapter,
+             daemon_cls=daemon_cls,
+         )
+         trainer.init_workers()
+         trainer.fit()
+
+
+ if __name__ == "__main__":
+     main()
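
This entrypoint exposes two paths into the same TaskRunner: the Hydra CLI main(), which reads the packaged config, and the programmatic run_ppo(), which lets callers supply in-memory datasets plus an optional LightningStore, LLMProxy, and TraceAdapter. A hypothetical sketch of the programmatic path (not from the package; assumes ray, hydra, omegaconf, and verl are installed, and train_rows is an illustrative placeholder):

from omegaconf import OmegaConf

from mantisdk.verl.daemon import AgentModeDaemon
from mantisdk.verl.entrypoint import run_ppo
from mantisdk.verl.trainer import MantisdkTrainer

# Packaged defaults; real runs would override model path, data files, GPUs, etc.
config = OmegaConf.load("mantisdk/verl/config.yaml")

# Illustrative in-memory dataset; wrapped in LoadedDataset inside TaskRunner.run.
train_rows = [{"prompt": "What is 2 + 2?", "extra_info": {"index": 0}}]

run_ppo(
    config,
    train_dataset=train_rows,
    val_dataset=None,      # falls back to AgentDataset built from config.data.val_files
    store=None,            # optional LightningStore
    llm_proxy=None,        # optional LLMProxy
    adapter=None,          # optional TraceAdapter
    trainer_cls=MantisdkTrainer,
    daemon_cls=AgentModeDaemon,
)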