@elizaos/training 2.0.0-alpha.10

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (224) hide show
  1. package/Dockerfile +75 -0
  2. package/LICENSE +21 -0
  3. package/Makefile +374 -0
  4. package/README.md +346 -0
  5. package/config/rubrics.json +137 -0
  6. package/docker-compose.test.yml +57 -0
  7. package/package.json +57 -0
  8. package/python/config/babylon_atropos.yaml +90 -0
  9. package/python/config/profiles/12gb.json +11 -0
  10. package/python/config/profiles/16gb.json +10 -0
  11. package/python/config/profiles/24gb.json +10 -0
  12. package/python/config/profiles/48gb.json +10 -0
  13. package/python/config/profiles/cpu.json +11 -0
  14. package/python/config/profiles/l40-2gpu-safe.json +20 -0
  15. package/python/config/profiles/l40-2gpu.json +22 -0
  16. package/python/config/profiles/l40-4gpu.json +21 -0
  17. package/python/config/profiles/l40.json +17 -0
  18. package/python/config/tinker_training.yaml +143 -0
  19. package/python/curriculum_state.json +165 -0
  20. package/python/env.template +86 -0
  21. package/python/env.training.template +46 -0
  22. package/python/pyproject.toml +41 -0
  23. package/python/requirements-ci.txt +31 -0
  24. package/python/requirements.txt +87 -0
  25. package/python/scripts/__init__.py +4 -0
  26. package/python/scripts/benchmark_should_respond.py +190 -0
  27. package/python/scripts/debug_inference.py +62 -0
  28. package/python/scripts/import_json_trajectories.py +412 -0
  29. package/python/scripts/local-finetune/README.md +63 -0
  30. package/python/scripts/local-finetune/ingest_and_score.py +139 -0
  31. package/python/scripts/local-finetune/merge_model.py +32 -0
  32. package/python/scripts/local-finetune/test_adapter.py +91 -0
  33. package/python/scripts/local-finetune/train_from_csv.py +132 -0
  34. package/python/scripts/merge_trajectories.py +318 -0
  35. package/python/scripts/optimize_prompt_grpo.py +269 -0
  36. package/python/scripts/run_ab_test.py +143 -0
  37. package/python/scripts/run_full_pipeline.py +544 -0
  38. package/python/scripts/run_tinker_training.py +192 -0
  39. package/python/scripts/run_training.py +914 -0
  40. package/python/scripts/test_generation.py +29 -0
  41. package/python/scripts/test_judge.py +155 -0
  42. package/python/scripts/test_pipeline.py +356 -0
  43. package/python/scripts/test_trained_model.py +380 -0
  44. package/python/scripts/train_grpo.py +360 -0
  45. package/python/scripts/train_jsonl.py +223 -0
  46. package/python/scripts/train_local.py +528 -0
  47. package/python/setup.py +20 -0
  48. package/python/src/__init__.py +190 -0
  49. package/python/src/data_bridge/__init__.py +24 -0
  50. package/python/src/data_bridge/converter.py +435 -0
  51. package/python/src/data_bridge/reader.py +393 -0
  52. package/python/src/models.py +283 -0
  53. package/python/src/training/__init__.py +605 -0
  54. package/python/src/training/ab_testing.py +404 -0
  55. package/python/src/training/action_executor.py +621 -0
  56. package/python/src/training/archetype_trainer.py +347 -0
  57. package/python/src/training/atropos_trainer.py +980 -0
  58. package/python/src/training/babylon_env.py +1254 -0
  59. package/python/src/training/error_recovery.py +647 -0
  60. package/python/src/training/evaluation.py +856 -0
  61. package/python/src/training/fast_simulator.py +880 -0
  62. package/python/src/training/format_validator.py +584 -0
  63. package/python/src/training/hybrid_env.py +522 -0
  64. package/python/src/training/kl_controller.py +628 -0
  65. package/python/src/training/multi_prompt_dataset.py +883 -0
  66. package/python/src/training/multi_turn.py +656 -0
  67. package/python/src/training/online_env.py +1084 -0
  68. package/python/src/training/quality_scorer.py +391 -0
  69. package/python/src/training/quality_utils.py +633 -0
  70. package/python/src/training/rewards.py +1344 -0
  71. package/python/src/training/rlaif_env.py +17 -0
  72. package/python/src/training/rollout_generator.py +502 -0
  73. package/python/src/training/rubric_loader.py +198 -0
  74. package/python/src/training/scenario_pool.py +1072 -0
  75. package/python/src/training/schemas.py +481 -0
  76. package/python/src/training/service_manager.py +552 -0
  77. package/python/src/training/simulation_bridge.py +535 -0
  78. package/python/src/training/tick_reward_attribution.py +399 -0
  79. package/python/src/training/tinker_client.py +575 -0
  80. package/python/src/training/tinker_trainer.py +646 -0
  81. package/python/src/training/tokenization_utils.py +402 -0
  82. package/python/tests/e2e/__init__.py +13 -0
  83. package/python/tests/e2e/conftest.py +258 -0
  84. package/python/tests/e2e/test_full_pipeline.py +643 -0
  85. package/python/tests/e2e/test_online_training_e2e.py +365 -0
  86. package/python/tests/integration/__init__.py +12 -0
  87. package/python/tests/integration/conftest.py +383 -0
  88. package/python/tests/integration/test_db_integration.py +649 -0
  89. package/python/tests/integration/test_json_mode_integration.py +554 -0
  90. package/python/tests/test_action_executor.py +594 -0
  91. package/python/tests/test_archetype_scoring.py +1027 -0
  92. package/python/tests/test_atropos_integration.py +360 -0
  93. package/python/tests/test_evaluation.py +727 -0
  94. package/python/tests/test_format_validator.py +486 -0
  95. package/python/tests/test_kl_controller.py +432 -0
  96. package/python/tests/test_lr_scheduler.py +579 -0
  97. package/python/tests/test_multi_turn.py +590 -0
  98. package/python/tests/test_online_env.py +519 -0
  99. package/python/tests/test_quality_scorer.py +474 -0
  100. package/python/tests/test_scenario_pool.py +735 -0
  101. package/python/tests/test_service_manager.py +585 -0
  102. package/python/tests/test_simulation_rollout.py +581 -0
  103. package/python/tests/test_tokenization_utils.py +501 -0
  104. package/python/tests/test_training_orchestrator.py +497 -0
  105. package/python/tests/test_training_output_structure.py +661 -0
  106. package/research-output/training-runs/training-run-1770772042899.json +26 -0
  107. package/research-output/training-runs/training-run-1770930079670.json +32 -0
  108. package/research-output/training-runs/training-run-1770930143700.json +44 -0
  109. package/research-output/training-runs/training-run-1770930183638.json +38 -0
  110. package/research-output/training-runs/training-run-1770930442049.json +38 -0
  111. package/research-output/training-runs/training-run-1770930793243.json +38 -0
  112. package/research-output/training-runs/training-run-1771276293257.json +38 -0
  113. package/research-output/training-runs/training-run-1771276389280.json +38 -0
  114. package/research-output/training-runs/training-run-1771276502776.json +38 -0
  115. package/research-output/training-runs/training-run-1771277340748.json +38 -0
  116. package/research-output/training-runs/training-run-1773013658993.json +38 -0
  117. package/research-output/training-runs/training-run-1773013861014.json +38 -0
  118. package/research-output/training-runs/training-run-1773014215983.json +38 -0
  119. package/scripts/assess-training-data.ts +422 -0
  120. package/scripts/e2e-training-test.ts +550 -0
  121. package/scripts/export-rubrics.ts +64 -0
  122. package/scripts/generate-research-report.ts +1523 -0
  123. package/scripts/generate_dataset.sh +173 -0
  124. package/scripts/generate_should_respond.ts +267 -0
  125. package/scripts/generate_should_respond_dataset.ts +162 -0
  126. package/scripts/json-mode-benchmark.ts +399 -0
  127. package/scripts/rank_trajectories.ts +207 -0
  128. package/scripts/real-archetype-benchmark.ts +210 -0
  129. package/scripts/run-baseline-comparison.ts +116 -0
  130. package/scripts/run-full-pipeline.ts +272 -0
  131. package/scripts/run_rlaif_loop.ts +78 -0
  132. package/scripts/run_task_benchmark.ts +247 -0
  133. package/scripts/runpod_setup.sh +137 -0
  134. package/scripts/runpod_validate.sh +147 -0
  135. package/scripts/test-model-in-game.ts +955 -0
  136. package/scripts/test-scoring.ts +73 -0
  137. package/scripts/test-trained-model.ts +209 -0
  138. package/scripts/train-and-test.ts +824 -0
  139. package/scripts/verify-final.ts +118 -0
  140. package/src/adapter.ts +516 -0
  141. package/src/archetypes/ArchetypeConfigService.ts +626 -0
  142. package/src/archetypes/derive-archetype.ts +249 -0
  143. package/src/archetypes/index.ts +22 -0
  144. package/src/benchmark/ArchetypeMatchupBenchmark.ts +825 -0
  145. package/src/benchmark/BenchmarkChartGenerator.ts +748 -0
  146. package/src/benchmark/BenchmarkDataGenerator.ts +1288 -0
  147. package/src/benchmark/BenchmarkDataViewer.ts +324 -0
  148. package/src/benchmark/BenchmarkHistoryService.ts +221 -0
  149. package/src/benchmark/BenchmarkRunner.ts +685 -0
  150. package/src/benchmark/BenchmarkValidator.ts +204 -0
  151. package/src/benchmark/FastEvalRunner.ts +225 -0
  152. package/src/benchmark/MetricsValidator.ts +165 -0
  153. package/src/benchmark/MetricsVisualizer.ts +909 -0
  154. package/src/benchmark/ModelBenchmarkService.ts +611 -0
  155. package/src/benchmark/ModelRegistry.ts +158 -0
  156. package/src/benchmark/RulerBenchmarkIntegration.ts +235 -0
  157. package/src/benchmark/SimulationA2AInterface.ts +1169 -0
  158. package/src/benchmark/SimulationEngine.ts +832 -0
  159. package/src/benchmark/TaskRunner.ts +94 -0
  160. package/src/benchmark/__tests__/BenchmarkRunner.test.ts +534 -0
  161. package/src/benchmark/__tests__/HeadToHead.test.ts +126 -0
  162. package/src/benchmark/index.ts +91 -0
  163. package/src/benchmark/parseSimulationMetrics.ts +124 -0
  164. package/src/benchmark/simulation-types.ts +78 -0
  165. package/src/dependencies.ts +475 -0
  166. package/src/generation/TrajectoryGenerator.ts +387 -0
  167. package/src/generation/index.ts +12 -0
  168. package/src/huggingface/HuggingFaceDatasetUploader.ts +636 -0
  169. package/src/huggingface/HuggingFaceIntegrationService.ts +426 -0
  170. package/src/huggingface/HuggingFaceModelUploader.ts +532 -0
  171. package/src/huggingface/index.ts +27 -0
  172. package/src/huggingface/shared/HuggingFaceUploadUtil.ts +206 -0
  173. package/src/index.ts +102 -0
  174. package/src/init-training.ts +53 -0
  175. package/src/metrics/TrajectoryMetricsExtractor.ts +653 -0
  176. package/src/metrics/__tests__/TrajectoryMetricsExtractor.test.ts +759 -0
  177. package/src/metrics/index.ts +8 -0
  178. package/src/metrics/types.ts +200 -0
  179. package/src/rubrics/__tests__/index.test.ts +184 -0
  180. package/src/rubrics/ass-kisser.ts +85 -0
  181. package/src/rubrics/degen.ts +80 -0
  182. package/src/rubrics/goody-twoshoes.ts +84 -0
  183. package/src/rubrics/index.ts +236 -0
  184. package/src/rubrics/information-trader.ts +84 -0
  185. package/src/rubrics/infosec.ts +101 -0
  186. package/src/rubrics/liar.ts +104 -0
  187. package/src/rubrics/perps-trader.ts +87 -0
  188. package/src/rubrics/researcher.ts +81 -0
  189. package/src/rubrics/scammer.ts +82 -0
  190. package/src/rubrics/social-butterfly.ts +73 -0
  191. package/src/rubrics/super-predictor.ts +97 -0
  192. package/src/rubrics/trader.ts +67 -0
  193. package/src/scoring/ArchetypeScoringService.ts +486 -0
  194. package/src/scoring/JudgePromptBuilder.ts +556 -0
  195. package/src/scoring/LLMJudgeCache.ts +401 -0
  196. package/src/scoring/index.ts +9 -0
  197. package/src/training/AutomationPipeline.ts +916 -0
  198. package/src/training/BenchmarkService.ts +518 -0
  199. package/src/training/ConfigValidator.ts +220 -0
  200. package/src/training/MarketOutcomesTracker.ts +187 -0
  201. package/src/training/ModelDeployer.ts +186 -0
  202. package/src/training/ModelFetcher.ts +76 -0
  203. package/src/training/ModelSelectionService.ts +341 -0
  204. package/src/training/ModelUsageVerifier.ts +160 -0
  205. package/src/training/MultiModelOrchestrator.ts +580 -0
  206. package/src/training/RLModelConfig.ts +407 -0
  207. package/src/training/RewardBackpropagationService.ts +149 -0
  208. package/src/training/RulerScoringService.ts +666 -0
  209. package/src/training/TrainingMonitor.ts +166 -0
  210. package/src/training/TrajectoryRecorder.ts +399 -0
  211. package/src/training/__tests__/TrajectoryRecorder.test.ts +472 -0
  212. package/src/training/index.ts +100 -0
  213. package/src/training/logRLConfig.ts +34 -0
  214. package/src/training/pipeline.ts +129 -0
  215. package/src/training/storage/ModelStorageService.ts +279 -0
  216. package/src/training/storage/TrainingDataArchiver.ts +197 -0
  217. package/src/training/storage/index.ts +17 -0
  218. package/src/training/types.ts +207 -0
  219. package/src/training/window-utils.ts +138 -0
  220. package/src/utils/index.ts +101 -0
  221. package/src/utils/logger.ts +59 -0
  222. package/src/utils/snowflake.ts +17 -0
  223. package/src/utils/synthetic-detector.ts +111 -0
  224. package/tsconfig.json +20 -0
@@ -0,0 +1,360 @@
1
+ """
2
+ Integration tests for Atropos RLAIF implementation
3
+
4
+ Tests:
5
+ 1. Module imports work correctly
6
+ 2. Data conversion functions work
7
+ 3. Reward functions produce valid outputs
8
+ 4. Environment can be instantiated (mock mode)
9
+ """
10
+
11
from datetime import datetime
from typing import TYPE_CHECKING, Dict

import pytest

if TYPE_CHECKING:
    # Needed for annotations only; runtime imports stay inside the tests so
    # a missing optional module fails a single test, not module collection.
    from src.models import TrainingTrajectory
14
+
15
# Probe the optional heavy dependencies once at import time so that tests
# needing them can be skipped instead of erroring during collection.
try:
    import torch  # noqa: F401
except ImportError:
    HAS_TORCH = False
else:
    HAS_TORCH = True

try:
    import wandb  # noqa: F401
except ImportError:
    HAS_WANDB = False
else:
    HAS_WANDB = True

# Skip markers for tests/classes that depend on the optional packages.
requires_torch = pytest.mark.skipif(not HAS_TORCH, reason="torch not installed")
requires_wandb = pytest.mark.skipif(not HAS_WANDB, reason="wandb not installed")
30
+
31
+
32
+ # Test imports work
33
class TestImports:
    """Smoke checks that each public module can be imported."""

    def test_import_models(self):
        """The core data models are importable from ``src.models``."""
        from src.models import AtroposScoredGroup, TrainingTrajectory

        for symbol in (TrainingTrajectory, AtroposScoredGroup):
            assert symbol is not None

    def test_import_converter(self):
        """The converter and its result type are importable from ``src.data_bridge``."""
        from src.data_bridge import ScoredGroupResult, TrajectoryToAtroposConverter

        for symbol in (TrajectoryToAtroposConverter, ScoredGroupResult):
            assert symbol is not None

    def test_import_rewards(self):
        """The reward helpers are importable from ``src.training.rewards``."""
        from src.training.rewards import RewardNormalizer, pnl_reward

        for symbol in (pnl_reward, RewardNormalizer):
            assert symbol is not None

    @requires_torch
    def test_import_trainer(self):
        """The trainer is importable (skipped when torch is not installed)."""
        from src.training import AtroposTrainer

        assert AtroposTrainer is not None

    @requires_wandb
    def test_import_environment(self):
        """The environment is importable (skipped when wandb is not installed)."""
        from src.training import RLAIFEnv

        assert RLAIFEnv is not None
73
+
74
+
75
class TestRewardFunctions:
    """Checks for the reward helpers exposed by ``src.training.rewards``."""

    def test_pnl_reward_positive(self):
        """A profitable trajectory must yield a strictly positive reward."""
        from src.training.rewards import TrajectoryRewardInputs, pnl_reward

        profitable = TrajectoryRewardInputs(
            final_pnl=500.0,
            starting_balance=10000.0,
            end_balance=10500.0,
        )
        assert pnl_reward(profitable) > 0.0

    def test_pnl_reward_negative(self):
        """A losing trajectory must yield a strictly negative reward."""
        from src.training.rewards import TrajectoryRewardInputs, pnl_reward

        losing = TrajectoryRewardInputs(
            final_pnl=-500.0,
            starting_balance=10000.0,
            end_balance=9500.0,
        )
        assert pnl_reward(losing) < 0.0

    def test_pnl_reward_zero(self):
        """A break-even trajectory must score within 0.1 of zero."""
        from src.training.rewards import TrajectoryRewardInputs, pnl_reward

        break_even = TrajectoryRewardInputs(
            final_pnl=0.0,
            starting_balance=10000.0,
            end_balance=10000.0,
        )
        outcome = pnl_reward(break_even)
        assert abs(outcome) <= 0.1

    def test_archetype_composite_reward(self):
        """The archetype-aware composite reward stays inside ``[0, 1]``."""
        from src.training.rewards import (
            BehaviorMetrics,
            TrajectoryRewardInputs,
            archetype_composite_reward,
        )

        reward_inputs = TrajectoryRewardInputs(
            final_pnl=500.0,
            starting_balance=10000.0,
            end_balance=10500.0,
            format_score=0.8,
            reasoning_score=0.75,
        )
        metrics = BehaviorMetrics(
            trades_executed=5,
            total_pnl=500.0,
            episode_length=10,
        )

        score = archetype_composite_reward(reward_inputs, "trader", metrics)
        assert 0.0 <= score <= 1.0

    def test_composite_reward_with_inputs(self):
        """The plain composite reward stays inside ``[0, 1]``."""
        from src.training.rewards import TrajectoryRewardInputs, composite_reward

        reward_inputs = TrajectoryRewardInputs(
            final_pnl=500.0,
            starting_balance=10000.0,
            end_balance=10500.0,
        )

        score = composite_reward(reward_inputs)
        assert 0.0 <= score <= 1.0

    def test_relative_scores(self):
        """Relative scores are bounded by ``[0, 1]`` and keep the input ranking."""
        from src.training.rewards import relative_scores

        # Raw reward floats, ordered high -> medium -> low.
        raw_rewards = [0.8, 0.5, 0.2]

        scores = relative_scores(raw_rewards)

        # Every normalized score must land inside [0, 1].
        for value in scores:
            assert 0.0 <= value <= 1.0
        # The best raw reward must keep the highest relative score.
        assert scores[0] > scores[1] > scores[2]

    def test_reward_normalizer(self):
        """``RewardNormalizer.normalize`` returns a float once it has seen data."""
        from src.training.rewards import RewardNormalizer

        normalizer = RewardNormalizer()

        # Feed a handful of observations before normalizing.
        observations = (0.5, 0.6, 0.7, 0.8, 0.55, 0.65, 0.75, 0.85)
        for observation in observations:
            normalizer.update(observation)

        assert isinstance(normalizer.normalize(0.65), float)
174
+
175
+
176
class TestConverter:
    """Test Trajectory to Atropos conversion."""

    def create_sample_trajectory(self) -> "TrainingTrajectory":
        """Build a five-step trajectory fixture.

        Step ``i`` carries a balance of ``10000 + 100 * i``, a P&L of
        ``100 * i``, ``i`` open positions, one LLM call and one successful
        ``trade`` action.

        Returns:
            A ``TrainingTrajectory`` model instance (not a plain dict).
        """
        from src.models import (
            TrainingTrajectory,
            TrajectoryStep,
            EnvironmentState,
            Action,
            LLMCall,
        )

        steps = [
            TrajectoryStep(
                step_number=i,
                timestamp=1000000 + i * 1000,
                environment_state=EnvironmentState(
                    agent_balance=10000.0 + i * 100,
                    agent_pnl=i * 100.0,
                    open_positions=i,
                ),
                provider_accesses=[],
                llm_calls=[
                    LLMCall(
                        model="gpt-4",
                        system_prompt="You are a trading agent",
                        user_prompt=f"Market update {i}",
                        response=f"Action {i}",
                        temperature=0.7,
                        max_tokens=100,
                        purpose="action",
                    )
                ],
                action=Action(
                    action_type="trade",
                    parameters={"amount": 100},
                    success=True,
                ),
                reward=0.1,
            )
            for i in range(5)
        ]

        return TrainingTrajectory(
            id="test-1",
            trajectory_id="traj-1",
            agent_id="agent-1",
            window_id="2024-01-01T00:00",
            start_time=datetime.now(),
            end_time=datetime.now(),
            duration_ms=5000,
            steps=steps,
            total_reward=0.5,
            final_pnl=400.0,
            episode_length=5,
            final_status="completed",
        )

    def test_convert_trajectory(self):
        """A single trajectory converts to messages plus matching metadata."""
        from src.data_bridge import TrajectoryToAtroposConverter

        converter = TrajectoryToAtroposConverter()
        traj = self.create_sample_trajectory()

        result = converter.convert_trajectory(traj)

        assert result is not None
        assert len(result.messages) >= 3  # system + at least one exchange
        assert result.metadata["trajectory_id"] == "traj-1"
        assert result.metadata["final_pnl"] == 400.0

    def test_convert_window_group(self):
        """Four trajectories convert into one group with four scored entries."""
        from src.data_bridge import TrajectoryToAtroposConverter

        converter = TrajectoryToAtroposConverter()
        trajs = [self.create_sample_trajectory() for _ in range(4)]

        # Give every trajectory a distinct ID; the fixture always uses "traj-1".
        for i, t in enumerate(trajs):
            t.trajectory_id = f"traj-{i}"

        result = converter.convert_window_group(trajs, None)

        assert result.group_size == 4
        assert len(result.scores) == 4
        assert len(result.messages) == 4

    def test_dropout(self):
        """With ``dropout_rate=0.5`` roughly half of the conversions are dropped."""
        from src.data_bridge import TrajectoryToAtroposConverter

        converter = TrajectoryToAtroposConverter(dropout_rate=0.5)

        # A dropped trajectory is reported as ``None``.
        dropped_count = sum(
            converter.convert_trajectory(self.create_sample_trajectory()) is None
            for _ in range(100)
        )

        # NOTE(review): presumably the converter drops at random with no fixed
        # seed; if so, (30, 70) is about +/-4 standard deviations for 100
        # draws at p=0.5, so spurious failures should be very rare.
        assert 30 < dropped_count < 70
279
+
280
+
281
@requires_torch
class TestTrainerConfig:
    """Checks for ``AtroposTrainingConfig`` (skipped when torch is not installed)."""

    def test_default_config(self):
        """A config built with no arguments exposes the expected defaults."""
        from src.training import AtroposTrainingConfig

        cfg = AtroposTrainingConfig()

        expected_defaults = {
            "model_name": "Qwen/Qwen2.5-3B-Instruct",
            "learning_rate": 1e-5,
            "training_steps": 100,
        }
        for field_name, expected in expected_defaults.items():
            assert getattr(cfg, field_name) == expected

    def test_custom_config(self):
        """Keyword overrides are stored on the config unchanged."""
        from src.training import AtroposTrainingConfig

        overrides = {
            "model_name": "Qwen/Qwen2.5-7B-Instruct",
            "training_steps": 50,
            "learning_rate": 5e-6,
        }
        cfg = AtroposTrainingConfig(**overrides)

        for field_name, expected in overrides.items():
            assert getattr(cfg, field_name) == expected
306
+
307
+
308
@requires_wandb
class TestEnvironmentConfig:
    """Checks for ``RLAIFEnvConfig`` (skipped when wandb is not installed)."""

    def test_default_config(self):
        """A config built with no arguments exposes the expected defaults."""
        from src.training import RLAIFEnvConfig

        cfg = RLAIFEnvConfig()

        expected_defaults = {
            "group_size": 4,
            "lookback_hours": 72,
            "min_agents_per_window": 2,
        }
        for field_name, expected in expected_defaults.items():
            assert getattr(cfg, field_name) == expected

    def test_custom_config(self):
        """Keyword overrides are stored on the config unchanged."""
        from src.training import RLAIFEnvConfig

        overrides = {
            "group_size": 8,
            "lookback_hours": 48,
            "judge_model": "gpt-4",
        }
        cfg = RLAIFEnvConfig(**overrides)

        for field_name, expected in overrides.items():
            assert getattr(cfg, field_name) == expected
333
+
334
+
335
class TestCalculateDropoutRate:
    """Checks for ``calculate_dropout_rate`` from ``src.data_bridge``."""

    def test_no_dropout_needed(self):
        """Fewer trajectories than the target means a rate of exactly zero."""
        from src.data_bridge import calculate_dropout_rate

        computed = calculate_dropout_rate(500, target_trajectories=1000)
        assert computed == 0.0

    def test_dropout_needed(self):
        """Twice the target yields a positive rate no greater than 0.3."""
        from src.data_bridge import calculate_dropout_rate

        computed = calculate_dropout_rate(2000, target_trajectories=1000)
        assert 0.0 < computed <= 0.3

    def test_max_dropout_cap(self):
        """A large surplus is clamped to the supplied ``max_dropout``."""
        from src.data_bridge import calculate_dropout_rate

        computed = calculate_dropout_rate(
            10000,
            target_trajectories=1000,
            max_dropout=0.2,
        )
        assert computed == 0.2
355
+
356
+
357
# Run tests with: pytest tests/test_atropos_integration.py -v
if __name__ == "__main__":
    # Propagate pytest's exit code; otherwise running this file directly
    # would exit 0 even when tests fail.
    raise SystemExit(pytest.main([__file__, "-v"]))
360
+