crca 1.4.0__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (306)
  1. CRCA.py +172 -7
  2. MODEL_CARD.md +53 -0
  3. PKG-INFO +8 -2
  4. RELEASE_NOTES.md +17 -0
  5. STABILITY.md +19 -0
  6. architecture/hybrid/consistency_engine.py +362 -0
  7. architecture/hybrid/conversation_manager.py +421 -0
  8. architecture/hybrid/explanation_generator.py +452 -0
  9. architecture/hybrid/few_shot_learner.py +533 -0
  10. architecture/hybrid/graph_compressor.py +286 -0
  11. architecture/hybrid/hybrid_agent.py +4398 -0
  12. architecture/hybrid/language_compiler.py +623 -0
  13. architecture/hybrid/main,py +0 -0
  14. architecture/hybrid/reasoning_tracker.py +322 -0
  15. architecture/hybrid/self_verifier.py +524 -0
  16. architecture/hybrid/task_decomposer.py +567 -0
  17. architecture/hybrid/text_corrector.py +341 -0
  18. benchmark_results/crca_core_benchmarks.json +178 -0
  19. branches/crca_sd/crca_sd_realtime.py +6 -2
  20. branches/general_agent/__init__.py +102 -0
  21. branches/general_agent/general_agent.py +1400 -0
  22. branches/general_agent/personality.py +169 -0
  23. branches/general_agent/utils/__init__.py +19 -0
  24. branches/general_agent/utils/prompt_builder.py +170 -0
  25. {crca-1.4.0.dist-info → crca-1.5.0.dist-info}/METADATA +8 -2
  26. {crca-1.4.0.dist-info → crca-1.5.0.dist-info}/RECORD +303 -20
  27. crca_core/__init__.py +35 -0
  28. crca_core/benchmarks/__init__.py +14 -0
  29. crca_core/benchmarks/synthetic_scm.py +103 -0
  30. crca_core/core/__init__.py +23 -0
  31. crca_core/core/api.py +120 -0
  32. crca_core/core/estimate.py +208 -0
  33. crca_core/core/godclass.py +72 -0
  34. crca_core/core/intervention_design.py +174 -0
  35. crca_core/core/lifecycle.py +48 -0
  36. crca_core/discovery/__init__.py +9 -0
  37. crca_core/discovery/tabular.py +193 -0
  38. crca_core/identify/__init__.py +171 -0
  39. crca_core/identify/backdoor.py +39 -0
  40. crca_core/identify/frontdoor.py +48 -0
  41. crca_core/identify/graph.py +106 -0
  42. crca_core/identify/id_algorithm.py +43 -0
  43. crca_core/identify/iv.py +48 -0
  44. crca_core/models/__init__.py +67 -0
  45. crca_core/models/provenance.py +56 -0
  46. crca_core/models/refusal.py +39 -0
  47. crca_core/models/result.py +83 -0
  48. crca_core/models/spec.py +151 -0
  49. crca_core/models/validation.py +68 -0
  50. crca_core/scm/__init__.py +9 -0
  51. crca_core/scm/linear_gaussian.py +198 -0
  52. crca_core/timeseries/__init__.py +6 -0
  53. crca_core/timeseries/pcmci.py +181 -0
  54. crca_llm/__init__.py +12 -0
  55. crca_llm/client.py +85 -0
  56. crca_llm/coauthor.py +118 -0
  57. crca_llm/orchestrator.py +289 -0
  58. crca_llm/types.py +21 -0
  59. crca_reasoning/__init__.py +16 -0
  60. crca_reasoning/critique.py +54 -0
  61. crca_reasoning/godclass.py +206 -0
  62. crca_reasoning/memory.py +24 -0
  63. crca_reasoning/rationale.py +10 -0
  64. crca_reasoning/react_controller.py +81 -0
  65. crca_reasoning/tool_router.py +97 -0
  66. crca_reasoning/types.py +40 -0
  67. crca_sd/__init__.py +15 -0
  68. crca_sd/crca_sd_core.py +2 -0
  69. crca_sd/crca_sd_governance.py +2 -0
  70. crca_sd/crca_sd_mpc.py +2 -0
  71. crca_sd/crca_sd_realtime.py +2 -0
  72. crca_sd/crca_sd_tui.py +2 -0
  73. cuda-keyring_1.1-1_all.deb +0 -0
  74. cuda-keyring_1.1-1_all.deb.1 +0 -0
  75. docs/IMAGE_ANNOTATION_USAGE.md +539 -0
  76. docs/INSTALL_DEEPSPEED.md +125 -0
  77. docs/api/branches/crca-cg.md +19 -0
  78. docs/api/branches/crca-q.md +27 -0
  79. docs/api/branches/crca-sd.md +37 -0
  80. docs/api/branches/general-agent.md +24 -0
  81. docs/api/branches/overview.md +19 -0
  82. docs/api/crca/agent-methods.md +62 -0
  83. docs/api/crca/operations.md +79 -0
  84. docs/api/crca/overview.md +32 -0
  85. docs/api/image-annotation/engine.md +52 -0
  86. docs/api/image-annotation/overview.md +17 -0
  87. docs/api/schemas/annotation.md +34 -0
  88. docs/api/schemas/core-schemas.md +82 -0
  89. docs/api/schemas/overview.md +32 -0
  90. docs/api/schemas/policy.md +30 -0
  91. docs/api/utils/conversation.md +22 -0
  92. docs/api/utils/graph-reasoner.md +32 -0
  93. docs/api/utils/overview.md +21 -0
  94. docs/api/utils/router.md +19 -0
  95. docs/api/utils/utilities.md +97 -0
  96. docs/architecture/causal-graphs.md +41 -0
  97. docs/architecture/data-flow.md +29 -0
  98. docs/architecture/design-principles.md +33 -0
  99. docs/architecture/hybrid-agent/components.md +38 -0
  100. docs/architecture/hybrid-agent/consistency.md +26 -0
  101. docs/architecture/hybrid-agent/overview.md +44 -0
  102. docs/architecture/hybrid-agent/reasoning.md +22 -0
  103. docs/architecture/llm-integration.md +26 -0
  104. docs/architecture/modular-structure.md +37 -0
  105. docs/architecture/overview.md +69 -0
  106. docs/architecture/policy-engine-arch.md +29 -0
  107. docs/branches/crca-cg/corposwarm.md +39 -0
  108. docs/branches/crca-cg/esg-scoring.md +30 -0
  109. docs/branches/crca-cg/multi-agent.md +35 -0
  110. docs/branches/crca-cg/overview.md +40 -0
  111. docs/branches/crca-q/alternative-data.md +55 -0
  112. docs/branches/crca-q/architecture.md +71 -0
  113. docs/branches/crca-q/backtesting.md +45 -0
  114. docs/branches/crca-q/causal-engine.md +33 -0
  115. docs/branches/crca-q/execution.md +39 -0
  116. docs/branches/crca-q/market-data.md +60 -0
  117. docs/branches/crca-q/overview.md +58 -0
  118. docs/branches/crca-q/philosophy.md +60 -0
  119. docs/branches/crca-q/portfolio-optimization.md +66 -0
  120. docs/branches/crca-q/risk-management.md +102 -0
  121. docs/branches/crca-q/setup.md +65 -0
  122. docs/branches/crca-q/signal-generation.md +61 -0
  123. docs/branches/crca-q/signal-validation.md +43 -0
  124. docs/branches/crca-sd/core.md +84 -0
  125. docs/branches/crca-sd/governance.md +53 -0
  126. docs/branches/crca-sd/mpc-solver.md +65 -0
  127. docs/branches/crca-sd/overview.md +59 -0
  128. docs/branches/crca-sd/realtime.md +28 -0
  129. docs/branches/crca-sd/tui.md +20 -0
  130. docs/branches/general-agent/overview.md +37 -0
  131. docs/branches/general-agent/personality.md +36 -0
  132. docs/branches/general-agent/prompt-builder.md +30 -0
  133. docs/changelog/index.md +79 -0
  134. docs/contributing/code-style.md +69 -0
  135. docs/contributing/documentation.md +43 -0
  136. docs/contributing/overview.md +29 -0
  137. docs/contributing/testing.md +29 -0
  138. docs/core/crcagent/async-operations.md +65 -0
  139. docs/core/crcagent/automatic-extraction.md +107 -0
  140. docs/core/crcagent/batch-prediction.md +80 -0
  141. docs/core/crcagent/bayesian-inference.md +60 -0
  142. docs/core/crcagent/causal-graph.md +92 -0
  143. docs/core/crcagent/counterfactuals.md +96 -0
  144. docs/core/crcagent/deterministic-simulation.md +78 -0
  145. docs/core/crcagent/dual-mode-operation.md +82 -0
  146. docs/core/crcagent/initialization.md +88 -0
  147. docs/core/crcagent/optimization.md +65 -0
  148. docs/core/crcagent/overview.md +63 -0
  149. docs/core/crcagent/time-series.md +57 -0
  150. docs/core/schemas/annotation.md +30 -0
  151. docs/core/schemas/core-schemas.md +82 -0
  152. docs/core/schemas/overview.md +30 -0
  153. docs/core/schemas/policy.md +41 -0
  154. docs/core/templates/base-agent.md +31 -0
  155. docs/core/templates/feature-mixins.md +31 -0
  156. docs/core/templates/overview.md +29 -0
  157. docs/core/templates/templates-guide.md +75 -0
  158. docs/core/tools/mcp-client.md +34 -0
  159. docs/core/tools/overview.md +24 -0
  160. docs/core/utils/conversation.md +27 -0
  161. docs/core/utils/graph-reasoner.md +29 -0
  162. docs/core/utils/overview.md +27 -0
  163. docs/core/utils/router.md +27 -0
  164. docs/core/utils/utilities.md +97 -0
  165. docs/css/custom.css +84 -0
  166. docs/examples/basic-usage.md +57 -0
  167. docs/examples/general-agent/general-agent-examples.md +50 -0
  168. docs/examples/hybrid-agent/hybrid-agent-examples.md +56 -0
  169. docs/examples/image-annotation/image-annotation-examples.md +54 -0
  170. docs/examples/integration/integration-examples.md +58 -0
  171. docs/examples/overview.md +37 -0
  172. docs/examples/trading/trading-examples.md +46 -0
  173. docs/features/causal-reasoning/advanced-topics.md +101 -0
  174. docs/features/causal-reasoning/counterfactuals.md +43 -0
  175. docs/features/causal-reasoning/do-calculus.md +50 -0
  176. docs/features/causal-reasoning/overview.md +47 -0
  177. docs/features/causal-reasoning/structural-models.md +52 -0
  178. docs/features/hybrid-agent/advanced-components.md +55 -0
  179. docs/features/hybrid-agent/core-components.md +64 -0
  180. docs/features/hybrid-agent/overview.md +34 -0
  181. docs/features/image-annotation/engine.md +82 -0
  182. docs/features/image-annotation/features.md +113 -0
  183. docs/features/image-annotation/integration.md +75 -0
  184. docs/features/image-annotation/overview.md +53 -0
  185. docs/features/image-annotation/quickstart.md +73 -0
  186. docs/features/policy-engine/doctrine-ledger.md +105 -0
  187. docs/features/policy-engine/monitoring.md +44 -0
  188. docs/features/policy-engine/mpc-control.md +89 -0
  189. docs/features/policy-engine/overview.md +46 -0
  190. docs/getting-started/configuration.md +225 -0
  191. docs/getting-started/first-agent.md +164 -0
  192. docs/getting-started/installation.md +144 -0
  193. docs/getting-started/quickstart.md +137 -0
  194. docs/index.md +118 -0
  195. docs/js/mathjax.js +13 -0
  196. docs/lrm/discovery_proof_notes.md +25 -0
  197. docs/lrm/finetune_full.md +83 -0
  198. docs/lrm/math_appendix.md +120 -0
  199. docs/lrm/overview.md +32 -0
  200. docs/mkdocs.yml +238 -0
  201. docs/stylesheets/extra.css +21 -0
  202. docs_generated/crca_core/CounterfactualResult.md +12 -0
  203. docs_generated/crca_core/DiscoveryHypothesisResult.md +13 -0
  204. docs_generated/crca_core/DraftSpec.md +13 -0
  205. docs_generated/crca_core/EstimateResult.md +13 -0
  206. docs_generated/crca_core/IdentificationResult.md +17 -0
  207. docs_generated/crca_core/InterventionDesignResult.md +12 -0
  208. docs_generated/crca_core/LockedSpec.md +15 -0
  209. docs_generated/crca_core/RefusalResult.md +12 -0
  210. docs_generated/crca_core/ValidationReport.md +9 -0
  211. docs_generated/crca_core/index.md +13 -0
  212. examples/general_agent_example.py +277 -0
  213. examples/general_agent_quickstart.py +202 -0
  214. examples/general_agent_simple.py +92 -0
  215. examples/hybrid_agent_auto_extraction.py +84 -0
  216. examples/hybrid_agent_dictionary_demo.py +104 -0
  217. examples/hybrid_agent_enhanced.py +179 -0
  218. examples/hybrid_agent_general_knowledge.py +107 -0
  219. examples/image_annotation_quickstart.py +328 -0
  220. examples/test_hybrid_fixes.py +77 -0
  221. image_annotation/__init__.py +27 -0
  222. image_annotation/annotation_engine.py +2593 -0
  223. install_cuda_wsl2.sh +59 -0
  224. install_deepspeed.sh +56 -0
  225. install_deepspeed_simple.sh +87 -0
  226. mkdocs.yml +252 -0
  227. ollama/Modelfile +8 -0
  228. prompts/__init__.py +2 -1
  229. prompts/default_crca.py +9 -1
  230. prompts/general_agent.py +227 -0
  231. prompts/image_annotation.py +56 -0
  232. pyproject.toml +17 -2
  233. requirements-docs.txt +10 -0
  234. requirements.txt +21 -2
  235. schemas/__init__.py +26 -1
  236. schemas/annotation.py +222 -0
  237. schemas/conversation.py +193 -0
  238. schemas/hybrid.py +211 -0
  239. schemas/reasoning.py +276 -0
  240. schemas_export/crca_core/CounterfactualResult.schema.json +108 -0
  241. schemas_export/crca_core/DiscoveryHypothesisResult.schema.json +113 -0
  242. schemas_export/crca_core/DraftSpec.schema.json +635 -0
  243. schemas_export/crca_core/EstimateResult.schema.json +113 -0
  244. schemas_export/crca_core/IdentificationResult.schema.json +145 -0
  245. schemas_export/crca_core/InterventionDesignResult.schema.json +111 -0
  246. schemas_export/crca_core/LockedSpec.schema.json +646 -0
  247. schemas_export/crca_core/RefusalResult.schema.json +90 -0
  248. schemas_export/crca_core/ValidationReport.schema.json +62 -0
  249. scripts/build_lrm_dataset.py +80 -0
  250. scripts/export_crca_core_schemas.py +54 -0
  251. scripts/export_hf_lrm.py +37 -0
  252. scripts/export_ollama_gguf.py +45 -0
  253. scripts/generate_changelog.py +157 -0
  254. scripts/generate_crca_core_docs_from_schemas.py +86 -0
  255. scripts/run_crca_core_benchmarks.py +163 -0
  256. scripts/run_full_finetune.py +198 -0
  257. scripts/run_lrm_eval.py +31 -0
  258. templates/graph_management.py +29 -0
  259. tests/conftest.py +9 -0
  260. tests/test_core.py +2 -3
  261. tests/test_crca_core_discovery_tabular.py +15 -0
  262. tests/test_crca_core_estimate_dowhy.py +36 -0
  263. tests/test_crca_core_identify.py +18 -0
  264. tests/test_crca_core_intervention_design.py +36 -0
  265. tests/test_crca_core_linear_gaussian_scm.py +69 -0
  266. tests/test_crca_core_spec.py +25 -0
  267. tests/test_crca_core_timeseries_pcmci.py +15 -0
  268. tests/test_crca_llm_coauthor.py +12 -0
  269. tests/test_crca_llm_orchestrator.py +80 -0
  270. tests/test_hybrid_agent_llm_enhanced.py +556 -0
  271. tests/test_image_annotation_demo.py +376 -0
  272. tests/test_image_annotation_operational.py +408 -0
  273. tests/test_image_annotation_unit.py +551 -0
  274. tests/test_training_moe.py +13 -0
  275. training/__init__.py +42 -0
  276. training/datasets.py +140 -0
  277. training/deepspeed_zero2_0_5b.json +22 -0
  278. training/deepspeed_zero2_1_5b.json +22 -0
  279. training/deepspeed_zero3_0_5b.json +28 -0
  280. training/deepspeed_zero3_14b.json +28 -0
  281. training/deepspeed_zero3_h100_3gpu.json +20 -0
  282. training/deepspeed_zero3_offload.json +28 -0
  283. training/eval.py +92 -0
  284. training/finetune.py +516 -0
  285. training/public_datasets.py +89 -0
  286. training_data/react_train.jsonl +7473 -0
  287. utils/agent_discovery.py +311 -0
  288. utils/batch_processor.py +317 -0
  289. utils/conversation.py +78 -0
  290. utils/edit_distance.py +118 -0
  291. utils/formatter.py +33 -0
  292. utils/graph_reasoner.py +530 -0
  293. utils/rate_limiter.py +283 -0
  294. utils/router.py +2 -2
  295. utils/tool_discovery.py +307 -0
  296. webui/__init__.py +10 -0
  297. webui/app.py +229 -0
  298. webui/config.py +104 -0
  299. webui/static/css/style.css +332 -0
  300. webui/static/js/main.js +284 -0
  301. webui/templates/index.html +42 -0
  302. tests/test_crca_excel.py +0 -166
  303. tests/test_data_broker.py +0 -424
  304. tests/test_palantir.py +0 -349
  305. {crca-1.4.0.dist-info → crca-1.5.0.dist-info}/WHEEL +0 -0
  306. {crca-1.4.0.dist-info → crca-1.5.0.dist-info}/licenses/LICENSE +0 -0
training/finetune.py ADDED
@@ -0,0 +1,516 @@
+ """Low-compute finetuning pipeline (LoRA/QLoRA when available)."""
+
+ from __future__ import annotations
+
+ import os
+ from dataclasses import dataclass
+ from pathlib import Path
+ from typing import Dict, Optional
+
+ import torch
+
+ # Disable DeepSpeed op building if CUDA_HOME not set (prevents MissingCUDAException)
+ if "CUDA_HOME" not in os.environ and "DS_BUILD_OPS" not in os.environ:
+     os.environ["DS_BUILD_OPS"] = "0"
+
+ MODEL_REGISTRY: Dict[str, Dict[str, object]] = {
+     "google/switch-base-8": {"arch": "seq2seq", "moe": True},
+     "google/switch-large-16": {"arch": "seq2seq", "moe": True},
+ }
+
+
+ def _resolve_model_info(base_model: str) -> Dict[str, object]:
+     model = base_model.lower()
+     for model_id, info in MODEL_REGISTRY.items():
+         if model == model_id.lower() or model.startswith(model_id.lower()):
+             return {"arch": info.get("arch", "causal"), "moe": info.get("moe", False)}
+     if "switch" in model:
+         return {"arch": "seq2seq", "moe": True}
+     return {"arch": "causal", "moe": False}
+
+
+ @dataclass
+ class FinetuneConfig:
+     base_model: str = "microsoft/phi-2"
+     output_dir: str = "lrm_finetune_out"
+     train_file: str = "training_data/react_train.jsonl"
+     eval_file: Optional[str] = None
+     num_train_epochs: int = 1
+     per_device_batch_size: int = 2
+     gradient_accumulation_steps: int = 1
+     learning_rate: float = 2e-4
+     use_lora: bool = True
+     max_seq_length: int = 512
+     gradient_checkpointing: bool = False
+     fp16: bool = True
+     bf16: bool = False
+     deepspeed_config: Optional[str] = None
+
+
+ def full_finetune_qwen25_1_5b_config() -> FinetuneConfig:
+     """
+     Full finetune configuration for Qwen2.5-1.5B-Instruct.
+
+     Aggressively optimized for CRCA reasoning:
+     - Higher learning rate (5e-4), which smaller models tolerate, to push past boilerplate outputs
+     - Longer sequences (8192) to capture full reasoning chains
+     - Full finetune (no LoRA) for maximum reasoning capability
+     - DeepSpeed ZeRO-2 for memory efficiency
+     - 20 epochs for thorough convergence
+
+     Compatible with NVIDIA GPUs (e.g. H100 SXM).
+     """
+     return FinetuneConfig(
+         base_model="Qwen/Qwen2.5-1.5B-Instruct",
+         output_dir="lrm_qwen25_1_5b_full_finetune",
+         train_file="training_data/react_train.jsonl",
+         eval_file=None,
+         num_train_epochs=20,  # More epochs for CRCA reasoning convergence
+         per_device_batch_size=8,  # Cloud GPU optimized
+         gradient_accumulation_steps=16,  # Effective batch size: 128
+         learning_rate=5e-4,  # Higher LR for smaller models, aggressive for CRCA
+         use_lora=False,  # Full finetune for maximum reasoning capability
+         max_seq_length=8192,  # Longer sequences for reasoning chains
+         gradient_checkpointing=True,  # Enable for memory efficiency
+         fp16=True,
+         bf16=False,
+         deepspeed_config="training/deepspeed_zero2_1_5b.json",
+     )
+
+
+ def full_finetune_qwen25_7b_config() -> FinetuneConfig:
+     """
+     Full finetune configuration for Qwen2.5-7B-Instruct.
+
+     Aggressively optimized for CRCA reasoning:
+     - Moderate learning rate (2e-4) for stability on reasoning tasks
+     - 20 epochs for thorough convergence on CRCA tasks
+     - Full finetune (no LoRA) for maximum reasoning capability
+     - DeepSpeed ZeRO-3 with CPU offload for memory efficiency
+     - BF16 for better numerical stability on larger models
+
+     Compatible with NVIDIA GPUs (e.g. H100 SXM).
+     """
+     return FinetuneConfig(
+         base_model="Qwen/Qwen2.5-7B-Instruct",
+         output_dir="lrm_qwen25_7b_full_finetune",
+         train_file="training_data/react_train.jsonl",
+         eval_file=None,
+         num_train_epochs=20,  # Increased from 1 for CRCA reasoning convergence
+         per_device_batch_size=4,  # Cloud GPU optimized
+         gradient_accumulation_steps=32,  # Effective batch size: 128
+         learning_rate=2e-4,  # Moderate LR, optimized for CRCA reasoning stability
+         use_lora=False,  # Full finetune for maximum reasoning capability
+         max_seq_length=4096,  # Full context for reasoning chains
+         gradient_checkpointing=True,  # Enable for memory efficiency
+         fp16=False,
+         bf16=True,  # Better numerical stability for larger models
+         deepspeed_config="training/deepspeed_zero3_offload.json",
+     )
+
+
+ def full_finetune_qwen25_14b_config() -> FinetuneConfig:
+     """
+     Full finetune configuration for Qwen2.5-14B-Instruct.
+
+     Aggressively optimized for CRCA reasoning:
+     - Lower learning rate (1e-4) for stability on large models
+     - 20 epochs for thorough convergence on CRCA tasks
+     - Full finetune (no LoRA) for maximum reasoning capability
+     - DeepSpeed ZeRO-3 with CPU offload for memory efficiency
+     - BF16 required for numerical stability
+     - Larger gradient accumulation to preserve the effective batch size
+
+     Compatible with NVIDIA GPUs (e.g. H100 SXM).
+     """
+     return FinetuneConfig(
+         base_model="Qwen/Qwen2.5-14B-Instruct",
+         output_dir="lrm_qwen25_14b_full_finetune",
+         train_file="training_data/react_train.jsonl",
+         eval_file=None,
+         num_train_epochs=20,  # Thorough convergence for CRCA reasoning
+         per_device_batch_size=2,  # Memory constrained on the 14B model
+         gradient_accumulation_steps=64,  # Effective batch size: 128
+         learning_rate=1e-4,  # Lower LR for stability, still aggressive for CRCA
+         use_lora=False,  # Full finetune for maximum reasoning capability
+         max_seq_length=2048,  # Memory constraints on the 14B model
+         gradient_checkpointing=True,  # Critical for memory efficiency
+         fp16=False,
+         bf16=True,  # Required for numerical stability on large models
+         deepspeed_config="training/deepspeed_zero3_14b.json",
+     )
+
+
+ def full_finetune_switch_base_8_config() -> FinetuneConfig:
+     """
+     Full finetune configuration for Switch MoE base (Seq2Seq).
+
+     Optimized for Switch MoE (encoder-decoder):
+     - BF16 for H100 stability
+     - ZeRO-3 without CPU offload (H100-class GPUs)
+     - Moderate batch sizes for the Seq2Seq memory footprint
+     """
+     return FinetuneConfig(
+         base_model="google/switch-base-8",
+         output_dir="lrm_switch_base_8_full_finetune",
+         train_file="training_data/react_train.jsonl",
+         eval_file=None,
+         num_train_epochs=10,
+         per_device_batch_size=4,
+         gradient_accumulation_steps=16,
+         learning_rate=2e-4,
+         use_lora=False,
+         max_seq_length=1024,
+         gradient_checkpointing=True,
+         fp16=False,
+         bf16=True,
+         deepspeed_config="training/deepspeed_zero3_h100_3gpu.json",
+     )
+
+
+ def full_finetune_switch_large_16_config() -> FinetuneConfig:
+     """
+     Full finetune configuration for Switch MoE large (Seq2Seq).
+
+     - BF16 for numerical stability
+     - ZeRO-3 without CPU offload (H100-class GPUs)
+     - Conservative batch sizes to keep memory stable
+     """
+     return FinetuneConfig(
+         base_model="google/switch-large-16",
+         output_dir="lrm_switch_large_16_full_finetune",
+         train_file="training_data/react_train.jsonl",
+         eval_file=None,
+         num_train_epochs=10,
+         per_device_batch_size=2,
+         gradient_accumulation_steps=32,
+         learning_rate=1e-4,
+         use_lora=False,
+         max_seq_length=1024,
+         gradient_checkpointing=True,
+         fp16=False,
+         bf16=True,
+         deepspeed_config="training/deepspeed_zero3_h100_3gpu.json",
+     )
+
+
+ def full_finetune_qwen25_0_5b_config_cloud() -> FinetuneConfig:
+     """
+     Cloud GPU optimized configuration for Qwen2.5-0.5B-Instruct.
+
+     For GPUs with 16GB+ VRAM (RTX 3090, A4000, A100, etc.):
+     - Much larger batch sizes
+     - Longer sequences
+     - Full finetune (no LoRA needed)
+     """
+     return FinetuneConfig(
+         base_model="Qwen/Qwen2.5-0.5B-Instruct",
+         output_dir="lrm_qwen25_0_5b_full_finetune",
+         train_file="training_data/react_train.jsonl",
+         eval_file=None,
+         num_train_epochs=20,
+         per_device_batch_size=16,  # Cloud GPUs can handle this
+         gradient_accumulation_steps=8,  # Adjusted for effective batch size
+         learning_rate=4e-4,
+         use_lora=False,  # Full finetune on cloud GPU
+         max_seq_length=4096,  # Full context length on cloud
+         gradient_checkpointing=True,
+         fp16=True,
+         bf16=False,
+         deepspeed_config=None,  # Not needed on cloud GPUs
+     )
+
+
+ def full_finetune_qwen25_0_5b_config() -> FinetuneConfig:
+     """
+     LoRA finetune configuration for Qwen2.5-0.5B-Instruct on low-VRAM GPUs (~4GB).
+
+     Optimized for tight memory budgets:
+     - Per-device batch size of 1 with large gradient accumulation (effective batch size: 128)
+     - Higher learning rate (4e-4), which the 0.5B model can handle
+     - LoRA with quantization keeps the model on GPU and trains only ~1% of parameters
+     - Short sequences (512) and no DeepSpeed/CPU offload needed
+     """
+     return FinetuneConfig(
+         base_model="Qwen/Qwen2.5-0.5B-Instruct",
+         output_dir="lrm_qwen25_0_5b_full_finetune",
+         train_file="training_data/react_train.jsonl",
+         eval_file=None,
+         num_train_epochs=20,
+         per_device_batch_size=1,  # Must be 1 for a 4GB GPU
+         gradient_accumulation_steps=128,  # Large accumulation to maintain effective batch size
+         learning_rate=4e-4,  # Smaller models can handle higher learning rates
+         use_lora=True,  # LoRA avoids CPU offload - trains only ~1% of parameters
+         max_seq_length=512,  # Must be 512 or less for a 4GB GPU
+         gradient_checkpointing=False,  # Not needed with LoRA + quantization
+         fp16=True,
+         bf16=False,
+         deepspeed_config=None,  # No DeepSpeed needed with LoRA - stays on GPU
+     )
+
+
+ def run_finetune(cfg: FinetuneConfig) -> None:
+     """
+     Run finetuning on NVIDIA GPUs (e.g. H100 SXM).
+
+     Uses CUDA with NCCL for distributed training. Supports 4-bit/8-bit quantization
+     for LoRA, and full finetunes in BF16/FP16. DeepSpeed ZeRO-2/ZeRO-3 for multi-GPU.
+     """
+     if torch.cuda.is_available():
+         print("CUDA (NVIDIA GPU) detected - using CUDA settings")
+     else:
+         print("Warning: no CUDA device detected - training will fall back to CPU and be slow")
+
+     # Configure environment for single-GPU DeepSpeed (if using DeepSpeed)
+     if cfg.deepspeed_config:
+         if "RANK" not in os.environ:
+             os.environ["RANK"] = "0"
+         if "LOCAL_RANK" not in os.environ:
+             os.environ["LOCAL_RANK"] = "0"
+         if "WORLD_SIZE" not in os.environ:
+             os.environ["WORLD_SIZE"] = "1"
+         if "MASTER_ADDR" not in os.environ:
+             os.environ["MASTER_ADDR"] = "localhost"
+         if "MASTER_PORT" not in os.environ:
+             os.environ["MASTER_PORT"] = "29500"
+
+     try:
+         from datasets import load_dataset  # type: ignore
+         from transformers import (
+             AutoModelForCausalLM,
+             AutoModelForSeq2SeqLM,
+             AutoTokenizer,
+             DataCollatorForSeq2Seq,
+             Trainer,
+             TrainingArguments,
+         )  # type: ignore
+     except Exception as exc:
+         raise RuntimeError(
+             "Missing training dependencies. Install: transformers, datasets, accelerate, peft"
+         ) from exc
+
+     model_info = _resolve_model_info(cfg.base_model)
+     is_seq2seq = model_info.get("arch") == "seq2seq"
+     if model_info.get("moe"):
+         print("MoE model detected - using Seq2Seq pipeline")
+
+     # Load tokenizer with error handling
+     try:
+         tokenizer = AutoTokenizer.from_pretrained(cfg.base_model, trust_remote_code=True)
+     except Exception as exc:
+         raise RuntimeError(f"Failed to load tokenizer from {cfg.base_model}: {exc}") from exc
+
+     if tokenizer.pad_token is None:
+         tokenizer.pad_token = tokenizer.eos_token
+     if tokenizer.pad_token is None:
+         raise ValueError(f"Tokenizer from {cfg.base_model} has no pad_token or eos_token")
+
+     model_cls = AutoModelForSeq2SeqLM if is_seq2seq else AutoModelForCausalLM
+
+     # Load model (CUDA: 4-bit/8-bit for LoRA, full precision for full finetune)
+     if cfg.use_lora:
+         try:
+             from transformers import BitsAndBytesConfig
+             quantization_config = BitsAndBytesConfig(
+                 load_in_4bit=True,
+                 bnb_4bit_compute_dtype=torch.float16,
+             )
+             model = model_cls.from_pretrained(
+                 cfg.base_model,
+                 quantization_config=quantization_config,
+                 device_map="auto",
+             )
+             print("Using 4-bit quantization (CUDA)")
+         except Exception:
+             try:
+                 from transformers import BitsAndBytesConfig
+                 quantization_config = BitsAndBytesConfig(
+                     load_in_8bit=True,  # 8-bit mode has no compute-dtype knob in BitsAndBytesConfig
+                 )
+                 model = model_cls.from_pretrained(
+                     cfg.base_model,
+                     quantization_config=quantization_config,
+                     device_map="auto",
+                 )
+                 print("Using 8-bit quantization (4-bit not available)")
+             except Exception:
+                 model = model_cls.from_pretrained(
+                     cfg.base_model,
+                     torch_dtype=torch.bfloat16 if cfg.bf16 else torch.float16,
+                     low_cpu_mem_usage=True,
+                 )
+                 print("Using full precision (quantization not available)")
+     else:
+         # Full finetune: use BF16/FP16 based on config
+         model = model_cls.from_pretrained(
+             cfg.base_model,
+             torch_dtype=torch.bfloat16 if cfg.bf16 else torch.float16,
+             low_cpu_mem_usage=True,
+         )
+         precision_str = "BF16" if cfg.bf16 else "FP16"
+         print(f"Using full finetune with {precision_str} precision")
+
+     if cfg.use_lora:
+         try:
+             from peft import LoraConfig, get_peft_model  # type: ignore
+         except Exception as exc:
+             raise RuntimeError("LoRA requested but peft not installed. Install peft.") from exc
+
+         lora = LoraConfig(
+             r=8,
+             lora_alpha=16,
+             lora_dropout=0.05,
+             bias="none",
+             # Match the task type to the architecture (the Switch models are Seq2Seq)
+             task_type="SEQ_2_SEQ_LM" if is_seq2seq else "CAUSAL_LM",
+         )
+         model = get_peft_model(model, lora)
+         # LoRA doesn't need gradient checkpointing - it's already memory efficient
+     elif cfg.gradient_checkpointing:
+         model.gradient_checkpointing_enable()
+
+     if not Path(cfg.train_file).exists():
+         raise FileNotFoundError(f"Training file not found: {cfg.train_file}")
+     if cfg.eval_file and not Path(cfg.eval_file).exists():
+         raise FileNotFoundError(f"Eval file not found: {cfg.eval_file}")
+
+     data_files = {"train": cfg.train_file}
+     if cfg.eval_file:
+         data_files["validation"] = cfg.eval_file
+
+     try:
+         dataset = load_dataset("json", data_files=data_files)
+     except Exception as exc:
+         raise RuntimeError(f"Failed to load dataset from {data_files}: {exc}") from exc
+
+     if "train" not in dataset:
+         raise ValueError(f"Dataset must contain 'train' split, got: {list(dataset.keys())}")
+     if len(dataset["train"]) == 0:
+         raise ValueError("Training dataset is empty")
+
+     def _tokenize(examples):
+         if is_seq2seq:
+             inputs = tokenizer(
+                 examples["prompt"],
+                 truncation=True,
+                 padding="max_length",
+                 max_length=cfg.max_seq_length,
+                 return_tensors=None,
+             )
+             targets = tokenizer(
+                 text_target=examples["response"],
+                 truncation=True,
+                 padding="max_length",
+                 max_length=cfg.max_seq_length,
+                 return_tensors=None,
+             )
+             pad_token_id = tokenizer.pad_token_id
+             labels = [
+                 [token_id if token_id != pad_token_id else -100 for token_id in seq]
+                 for seq in targets["input_ids"]
+             ]
+             inputs["labels"] = labels
+             return inputs
+
+         texts = [p + "\n" + r for p, r in zip(examples["prompt"], examples["response"])]
+         tokenized = tokenizer(
+             texts,
+             truncation=True,
+             padding="max_length",
+             max_length=cfg.max_seq_length,
+             return_tensors=None,  # Return lists, not tensors
+         )
+         # For causal LM, labels are the same as input_ids
+         # Set padding tokens to -100 so they're ignored in loss calculation
+         labels = []
+         pad_token_id = tokenizer.pad_token_id
+         for input_ids in tokenized["input_ids"]:
+             label = [token_id if token_id != pad_token_id else -100 for token_id in input_ids]
+             labels.append(label)
+         tokenized["labels"] = labels
+         return tokenized
+
+     # Get column names before tokenization (handles both train and validation)
+     original_columns = dataset["train"].column_names
+
+     tokenized = dataset.map(
+         _tokenize,
+         batched=True,
+         remove_columns=original_columns,
+     )
+
+     # Validate configuration before training
+     if cfg.per_device_batch_size < 1:
+         raise ValueError(f"per_device_batch_size must be >= 1, got {cfg.per_device_batch_size}")
+     if cfg.gradient_accumulation_steps < 1:
+         raise ValueError(f"gradient_accumulation_steps must be >= 1, got {cfg.gradient_accumulation_steps}")
+     if cfg.learning_rate <= 0:
+         raise ValueError(f"learning_rate must be > 0, got {cfg.learning_rate}")
+     if cfg.max_seq_length < 1:
+         raise ValueError(f"max_seq_length must be >= 1, got {cfg.max_seq_length}")
+     if cfg.fp16 and cfg.bf16:
+         raise ValueError("Cannot use both fp16 and bf16 simultaneously")
+
+     # Resolve DeepSpeed config path if provided
+     deepspeed_config_path = None
+     if cfg.deepspeed_config:
+         deepspeed_config_path = str(Path(cfg.deepspeed_config).resolve())
+         if not Path(deepspeed_config_path).exists():
+             raise FileNotFoundError(f"DeepSpeed config file not found: {deepspeed_config_path}")
+
+     args = TrainingArguments(
+         output_dir=cfg.output_dir,
+         num_train_epochs=cfg.num_train_epochs,
+         per_device_train_batch_size=cfg.per_device_batch_size,
+         per_device_eval_batch_size=cfg.per_device_batch_size,
+         gradient_accumulation_steps=cfg.gradient_accumulation_steps,
+         learning_rate=cfg.learning_rate,
+         fp16=cfg.fp16,
+         bf16=cfg.bf16,
+         gradient_checkpointing=cfg.gradient_checkpointing if not cfg.use_lora else False,  # LoRA doesn't need it
+         deepspeed=deepspeed_config_path if deepspeed_config_path else None,
+         logging_steps=50,
+         save_steps=200,
+         eval_strategy="no" if cfg.eval_file is None else "steps",
+         eval_steps=200 if cfg.eval_file else None,  # Evaluate every 200 steps if eval_file provided
+         save_total_limit=2,
+         remove_unused_columns=False,
+         dataloader_pin_memory=False,  # Save memory
+         dataloader_num_workers=0,  # Reduce memory overhead
+         optim="adamw_torch",  # Standard AdamW (more memory efficient than fused variants)
+         max_grad_norm=1.0,  # Gradient clipping
+         warmup_steps=100,  # Warmup for better convergence
+         lr_scheduler_type="cosine",  # Cosine learning rate schedule for better convergence
+     )
+
+     train_dataset = tokenized["train"]
+     eval_dataset = tokenized.get("validation") if cfg.eval_file else None
+
+     if len(train_dataset) == 0:
+         raise ValueError("Tokenized training dataset is empty")
+
+     data_collator = None
+     if is_seq2seq:
+         data_collator = DataCollatorForSeq2Seq(
+             tokenizer=tokenizer,
+             model=model,
+             label_pad_token_id=-100,
+         )
+
+     # A data collator is optional when padding to max_length, but helps ensure consistency
+     trainer = Trainer(
+         model=model,
+         args=args,
+         train_dataset=train_dataset,
+         eval_dataset=eval_dataset,
+         data_collator=data_collator,
+     )
+
+     print(f"Starting training with {len(train_dataset)} examples")
+     if eval_dataset:
+         print(f"Evaluation dataset has {len(eval_dataset)} examples")
+     print(f"Effective batch size: {cfg.per_device_batch_size * cfg.gradient_accumulation_steps}")
+     print(f"Total training steps: {len(train_dataset) // (cfg.per_device_batch_size * cfg.gradient_accumulation_steps) * cfg.num_train_epochs}")
+
+     try:
+         trainer.train()
+     except Exception as exc:
+         raise RuntimeError(f"Training failed: {exc}") from exc
+
+     try:
+         trainer.save_model(cfg.output_dir)
+         print(f"Model saved to {cfg.output_dir}")
+     except Exception as exc:
+         raise RuntimeError(f"Failed to save model to {cfg.output_dir}: {exc}") from exc
+
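For orientation, a minimal usage sketch of the pipeline above (not part of the package diff; the config helper and run_finetune are taken from the file itself, and the __main__ guard is illustrative only):

    from training.finetune import full_finetune_qwen25_0_5b_config, run_finetune

    if __name__ == "__main__":
        # Any full_finetune_*_config() helper returns a FinetuneConfig;
        # this one targets low-VRAM GPUs via LoRA + quantization.
        cfg = full_finetune_qwen25_0_5b_config()
        run_finetune(cfg)

Swapping in full_finetune_qwen25_7b_config() or one of the Switch configs routes the same run_finetune entry point through the seq2seq branch and the matching DeepSpeed config, as resolved by _resolve_model_info.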
training/public_datasets.py ADDED
@@ -0,0 +1,89 @@
+ """Public dataset ingestion for hybrid LRM training."""
+
+ from __future__ import annotations
+
+ from dataclasses import dataclass
+ from typing import Any, Dict, Iterable, List, Optional
+
+ from training.datasets import ReActExample, normalize_text
+
+
+ @dataclass
+ class PublicDatasetConfig:
+     name: str
+     split: str
+     prompt_field: str
+     response_field: str
+     config_name: Optional[str] = None
+     prompt_template: Optional[str] = None
+     response_template: Optional[str] = None
+     system_field: Optional[str] = None
+     max_samples: Optional[int] = None
+     license_tag: Optional[str] = None
+     source_tag: Optional[str] = None
+
+
+ def default_public_configs() -> List[PublicDatasetConfig]:
+     """Conservative defaults with known field names."""
+     return [
+         PublicDatasetConfig(
+             name="openai/gsm8k",
+             config_name="main",
+             split="train",
+             prompt_field="question",
+             response_field="answer",
+             prompt_template="Question: {question}\nAnswer:",
+             response_template="{answer}",
+             license_tag="unknown",
+             source_tag="gsm8k",
+         )
+     ]
+
+
+ def _format_with_template(template: Optional[str], row: Dict[str, Any], field: str) -> str:
+     if template:
+         return template.format(**row)
+     value = row.get(field, "")
+     return str(value) if value is not None else ""
+
+
+ def load_public_examples(
+     configs: Iterable[PublicDatasetConfig],
+     *,
+     seed: int = 7,  # reserved for future shuffling; not used in the current implementation
+ ) -> List[ReActExample]:
+     try:
+         from datasets import load_dataset  # type: ignore
+     except Exception as exc:
+         raise RuntimeError(f"datasets library is required to load public datasets: {exc}") from exc
+
+     examples: List[ReActExample] = []
+     for cfg in configs:
+         if cfg.config_name:
+             dataset = load_dataset(cfg.name, cfg.config_name, split=cfg.split)
+         else:
+             dataset = load_dataset(cfg.name, split=cfg.split)
+         rows = list(dataset)
+         if cfg.max_samples is not None:
+             rows = rows[: cfg.max_samples]
+         for row in rows:
+             prompt = _format_with_template(cfg.prompt_template, row, cfg.prompt_field)
+             response = _format_with_template(cfg.response_template, row, cfg.response_field)
+             if cfg.system_field and row.get(cfg.system_field):
+                 system = str(row[cfg.system_field])
+                 prompt = f"System: {system}\n{prompt}"
+             prompt = normalize_text(prompt)
+             response = normalize_text(response)
+             if not prompt or not response:
+                 continue
+             tags = {
+                 "type": "public_reasoning",
+                 "dataset": cfg.name,
+             }
+             if cfg.license_tag:
+                 tags["license"] = cfg.license_tag
+             if cfg.source_tag:
+                 tags["source"] = cfg.source_tag
+             examples.append(ReActExample(prompt=prompt, response=response, tags=tags, refusal=False))
+     return examples
+
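A short sketch of how this loader might be driven (not part of the package diff; the names come from the file above, and the max_samples cap is set here only to keep the example run small):

    from training.public_datasets import default_public_configs, load_public_examples

    configs = default_public_configs()
    configs[0].max_samples = 100  # cap GSM8K rows for a quick smoke test
    examples = load_public_examples(configs)
    print(f"Loaded {len(examples)} examples; tags of the first: {examples[0].tags}")

The returned ReActExample objects carry license/source tags per row, so downstream dataset builders (e.g. scripts/build_lrm_dataset.py in this release) can filter by provenance.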