crca 1.4.0-py3-none-any.whl → 1.5.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- CRCA.py +172 -7
- MODEL_CARD.md +53 -0
- PKG-INFO +8 -2
- RELEASE_NOTES.md +17 -0
- STABILITY.md +19 -0
- architecture/hybrid/consistency_engine.py +362 -0
- architecture/hybrid/conversation_manager.py +421 -0
- architecture/hybrid/explanation_generator.py +452 -0
- architecture/hybrid/few_shot_learner.py +533 -0
- architecture/hybrid/graph_compressor.py +286 -0
- architecture/hybrid/hybrid_agent.py +4398 -0
- architecture/hybrid/language_compiler.py +623 -0
- architecture/hybrid/main,py +0 -0
- architecture/hybrid/reasoning_tracker.py +322 -0
- architecture/hybrid/self_verifier.py +524 -0
- architecture/hybrid/task_decomposer.py +567 -0
- architecture/hybrid/text_corrector.py +341 -0
- benchmark_results/crca_core_benchmarks.json +178 -0
- branches/crca_sd/crca_sd_realtime.py +6 -2
- branches/general_agent/__init__.py +102 -0
- branches/general_agent/general_agent.py +1400 -0
- branches/general_agent/personality.py +169 -0
- branches/general_agent/utils/__init__.py +19 -0
- branches/general_agent/utils/prompt_builder.py +170 -0
- {crca-1.4.0.dist-info → crca-1.5.0.dist-info}/METADATA +8 -2
- {crca-1.4.0.dist-info → crca-1.5.0.dist-info}/RECORD +303 -20
- crca_core/__init__.py +35 -0
- crca_core/benchmarks/__init__.py +14 -0
- crca_core/benchmarks/synthetic_scm.py +103 -0
- crca_core/core/__init__.py +23 -0
- crca_core/core/api.py +120 -0
- crca_core/core/estimate.py +208 -0
- crca_core/core/godclass.py +72 -0
- crca_core/core/intervention_design.py +174 -0
- crca_core/core/lifecycle.py +48 -0
- crca_core/discovery/__init__.py +9 -0
- crca_core/discovery/tabular.py +193 -0
- crca_core/identify/__init__.py +171 -0
- crca_core/identify/backdoor.py +39 -0
- crca_core/identify/frontdoor.py +48 -0
- crca_core/identify/graph.py +106 -0
- crca_core/identify/id_algorithm.py +43 -0
- crca_core/identify/iv.py +48 -0
- crca_core/models/__init__.py +67 -0
- crca_core/models/provenance.py +56 -0
- crca_core/models/refusal.py +39 -0
- crca_core/models/result.py +83 -0
- crca_core/models/spec.py +151 -0
- crca_core/models/validation.py +68 -0
- crca_core/scm/__init__.py +9 -0
- crca_core/scm/linear_gaussian.py +198 -0
- crca_core/timeseries/__init__.py +6 -0
- crca_core/timeseries/pcmci.py +181 -0
- crca_llm/__init__.py +12 -0
- crca_llm/client.py +85 -0
- crca_llm/coauthor.py +118 -0
- crca_llm/orchestrator.py +289 -0
- crca_llm/types.py +21 -0
- crca_reasoning/__init__.py +16 -0
- crca_reasoning/critique.py +54 -0
- crca_reasoning/godclass.py +206 -0
- crca_reasoning/memory.py +24 -0
- crca_reasoning/rationale.py +10 -0
- crca_reasoning/react_controller.py +81 -0
- crca_reasoning/tool_router.py +97 -0
- crca_reasoning/types.py +40 -0
- crca_sd/__init__.py +15 -0
- crca_sd/crca_sd_core.py +2 -0
- crca_sd/crca_sd_governance.py +2 -0
- crca_sd/crca_sd_mpc.py +2 -0
- crca_sd/crca_sd_realtime.py +2 -0
- crca_sd/crca_sd_tui.py +2 -0
- cuda-keyring_1.1-1_all.deb +0 -0
- cuda-keyring_1.1-1_all.deb.1 +0 -0
- docs/IMAGE_ANNOTATION_USAGE.md +539 -0
- docs/INSTALL_DEEPSPEED.md +125 -0
- docs/api/branches/crca-cg.md +19 -0
- docs/api/branches/crca-q.md +27 -0
- docs/api/branches/crca-sd.md +37 -0
- docs/api/branches/general-agent.md +24 -0
- docs/api/branches/overview.md +19 -0
- docs/api/crca/agent-methods.md +62 -0
- docs/api/crca/operations.md +79 -0
- docs/api/crca/overview.md +32 -0
- docs/api/image-annotation/engine.md +52 -0
- docs/api/image-annotation/overview.md +17 -0
- docs/api/schemas/annotation.md +34 -0
- docs/api/schemas/core-schemas.md +82 -0
- docs/api/schemas/overview.md +32 -0
- docs/api/schemas/policy.md +30 -0
- docs/api/utils/conversation.md +22 -0
- docs/api/utils/graph-reasoner.md +32 -0
- docs/api/utils/overview.md +21 -0
- docs/api/utils/router.md +19 -0
- docs/api/utils/utilities.md +97 -0
- docs/architecture/causal-graphs.md +41 -0
- docs/architecture/data-flow.md +29 -0
- docs/architecture/design-principles.md +33 -0
- docs/architecture/hybrid-agent/components.md +38 -0
- docs/architecture/hybrid-agent/consistency.md +26 -0
- docs/architecture/hybrid-agent/overview.md +44 -0
- docs/architecture/hybrid-agent/reasoning.md +22 -0
- docs/architecture/llm-integration.md +26 -0
- docs/architecture/modular-structure.md +37 -0
- docs/architecture/overview.md +69 -0
- docs/architecture/policy-engine-arch.md +29 -0
- docs/branches/crca-cg/corposwarm.md +39 -0
- docs/branches/crca-cg/esg-scoring.md +30 -0
- docs/branches/crca-cg/multi-agent.md +35 -0
- docs/branches/crca-cg/overview.md +40 -0
- docs/branches/crca-q/alternative-data.md +55 -0
- docs/branches/crca-q/architecture.md +71 -0
- docs/branches/crca-q/backtesting.md +45 -0
- docs/branches/crca-q/causal-engine.md +33 -0
- docs/branches/crca-q/execution.md +39 -0
- docs/branches/crca-q/market-data.md +60 -0
- docs/branches/crca-q/overview.md +58 -0
- docs/branches/crca-q/philosophy.md +60 -0
- docs/branches/crca-q/portfolio-optimization.md +66 -0
- docs/branches/crca-q/risk-management.md +102 -0
- docs/branches/crca-q/setup.md +65 -0
- docs/branches/crca-q/signal-generation.md +61 -0
- docs/branches/crca-q/signal-validation.md +43 -0
- docs/branches/crca-sd/core.md +84 -0
- docs/branches/crca-sd/governance.md +53 -0
- docs/branches/crca-sd/mpc-solver.md +65 -0
- docs/branches/crca-sd/overview.md +59 -0
- docs/branches/crca-sd/realtime.md +28 -0
- docs/branches/crca-sd/tui.md +20 -0
- docs/branches/general-agent/overview.md +37 -0
- docs/branches/general-agent/personality.md +36 -0
- docs/branches/general-agent/prompt-builder.md +30 -0
- docs/changelog/index.md +79 -0
- docs/contributing/code-style.md +69 -0
- docs/contributing/documentation.md +43 -0
- docs/contributing/overview.md +29 -0
- docs/contributing/testing.md +29 -0
- docs/core/crcagent/async-operations.md +65 -0
- docs/core/crcagent/automatic-extraction.md +107 -0
- docs/core/crcagent/batch-prediction.md +80 -0
- docs/core/crcagent/bayesian-inference.md +60 -0
- docs/core/crcagent/causal-graph.md +92 -0
- docs/core/crcagent/counterfactuals.md +96 -0
- docs/core/crcagent/deterministic-simulation.md +78 -0
- docs/core/crcagent/dual-mode-operation.md +82 -0
- docs/core/crcagent/initialization.md +88 -0
- docs/core/crcagent/optimization.md +65 -0
- docs/core/crcagent/overview.md +63 -0
- docs/core/crcagent/time-series.md +57 -0
- docs/core/schemas/annotation.md +30 -0
- docs/core/schemas/core-schemas.md +82 -0
- docs/core/schemas/overview.md +30 -0
- docs/core/schemas/policy.md +41 -0
- docs/core/templates/base-agent.md +31 -0
- docs/core/templates/feature-mixins.md +31 -0
- docs/core/templates/overview.md +29 -0
- docs/core/templates/templates-guide.md +75 -0
- docs/core/tools/mcp-client.md +34 -0
- docs/core/tools/overview.md +24 -0
- docs/core/utils/conversation.md +27 -0
- docs/core/utils/graph-reasoner.md +29 -0
- docs/core/utils/overview.md +27 -0
- docs/core/utils/router.md +27 -0
- docs/core/utils/utilities.md +97 -0
- docs/css/custom.css +84 -0
- docs/examples/basic-usage.md +57 -0
- docs/examples/general-agent/general-agent-examples.md +50 -0
- docs/examples/hybrid-agent/hybrid-agent-examples.md +56 -0
- docs/examples/image-annotation/image-annotation-examples.md +54 -0
- docs/examples/integration/integration-examples.md +58 -0
- docs/examples/overview.md +37 -0
- docs/examples/trading/trading-examples.md +46 -0
- docs/features/causal-reasoning/advanced-topics.md +101 -0
- docs/features/causal-reasoning/counterfactuals.md +43 -0
- docs/features/causal-reasoning/do-calculus.md +50 -0
- docs/features/causal-reasoning/overview.md +47 -0
- docs/features/causal-reasoning/structural-models.md +52 -0
- docs/features/hybrid-agent/advanced-components.md +55 -0
- docs/features/hybrid-agent/core-components.md +64 -0
- docs/features/hybrid-agent/overview.md +34 -0
- docs/features/image-annotation/engine.md +82 -0
- docs/features/image-annotation/features.md +113 -0
- docs/features/image-annotation/integration.md +75 -0
- docs/features/image-annotation/overview.md +53 -0
- docs/features/image-annotation/quickstart.md +73 -0
- docs/features/policy-engine/doctrine-ledger.md +105 -0
- docs/features/policy-engine/monitoring.md +44 -0
- docs/features/policy-engine/mpc-control.md +89 -0
- docs/features/policy-engine/overview.md +46 -0
- docs/getting-started/configuration.md +225 -0
- docs/getting-started/first-agent.md +164 -0
- docs/getting-started/installation.md +144 -0
- docs/getting-started/quickstart.md +137 -0
- docs/index.md +118 -0
- docs/js/mathjax.js +13 -0
- docs/lrm/discovery_proof_notes.md +25 -0
- docs/lrm/finetune_full.md +83 -0
- docs/lrm/math_appendix.md +120 -0
- docs/lrm/overview.md +32 -0
- docs/mkdocs.yml +238 -0
- docs/stylesheets/extra.css +21 -0
- docs_generated/crca_core/CounterfactualResult.md +12 -0
- docs_generated/crca_core/DiscoveryHypothesisResult.md +13 -0
- docs_generated/crca_core/DraftSpec.md +13 -0
- docs_generated/crca_core/EstimateResult.md +13 -0
- docs_generated/crca_core/IdentificationResult.md +17 -0
- docs_generated/crca_core/InterventionDesignResult.md +12 -0
- docs_generated/crca_core/LockedSpec.md +15 -0
- docs_generated/crca_core/RefusalResult.md +12 -0
- docs_generated/crca_core/ValidationReport.md +9 -0
- docs_generated/crca_core/index.md +13 -0
- examples/general_agent_example.py +277 -0
- examples/general_agent_quickstart.py +202 -0
- examples/general_agent_simple.py +92 -0
- examples/hybrid_agent_auto_extraction.py +84 -0
- examples/hybrid_agent_dictionary_demo.py +104 -0
- examples/hybrid_agent_enhanced.py +179 -0
- examples/hybrid_agent_general_knowledge.py +107 -0
- examples/image_annotation_quickstart.py +328 -0
- examples/test_hybrid_fixes.py +77 -0
- image_annotation/__init__.py +27 -0
- image_annotation/annotation_engine.py +2593 -0
- install_cuda_wsl2.sh +59 -0
- install_deepspeed.sh +56 -0
- install_deepspeed_simple.sh +87 -0
- mkdocs.yml +252 -0
- ollama/Modelfile +8 -0
- prompts/__init__.py +2 -1
- prompts/default_crca.py +9 -1
- prompts/general_agent.py +227 -0
- prompts/image_annotation.py +56 -0
- pyproject.toml +17 -2
- requirements-docs.txt +10 -0
- requirements.txt +21 -2
- schemas/__init__.py +26 -1
- schemas/annotation.py +222 -0
- schemas/conversation.py +193 -0
- schemas/hybrid.py +211 -0
- schemas/reasoning.py +276 -0
- schemas_export/crca_core/CounterfactualResult.schema.json +108 -0
- schemas_export/crca_core/DiscoveryHypothesisResult.schema.json +113 -0
- schemas_export/crca_core/DraftSpec.schema.json +635 -0
- schemas_export/crca_core/EstimateResult.schema.json +113 -0
- schemas_export/crca_core/IdentificationResult.schema.json +145 -0
- schemas_export/crca_core/InterventionDesignResult.schema.json +111 -0
- schemas_export/crca_core/LockedSpec.schema.json +646 -0
- schemas_export/crca_core/RefusalResult.schema.json +90 -0
- schemas_export/crca_core/ValidationReport.schema.json +62 -0
- scripts/build_lrm_dataset.py +80 -0
- scripts/export_crca_core_schemas.py +54 -0
- scripts/export_hf_lrm.py +37 -0
- scripts/export_ollama_gguf.py +45 -0
- scripts/generate_changelog.py +157 -0
- scripts/generate_crca_core_docs_from_schemas.py +86 -0
- scripts/run_crca_core_benchmarks.py +163 -0
- scripts/run_full_finetune.py +198 -0
- scripts/run_lrm_eval.py +31 -0
- templates/graph_management.py +29 -0
- tests/conftest.py +9 -0
- tests/test_core.py +2 -3
- tests/test_crca_core_discovery_tabular.py +15 -0
- tests/test_crca_core_estimate_dowhy.py +36 -0
- tests/test_crca_core_identify.py +18 -0
- tests/test_crca_core_intervention_design.py +36 -0
- tests/test_crca_core_linear_gaussian_scm.py +69 -0
- tests/test_crca_core_spec.py +25 -0
- tests/test_crca_core_timeseries_pcmci.py +15 -0
- tests/test_crca_llm_coauthor.py +12 -0
- tests/test_crca_llm_orchestrator.py +80 -0
- tests/test_hybrid_agent_llm_enhanced.py +556 -0
- tests/test_image_annotation_demo.py +376 -0
- tests/test_image_annotation_operational.py +408 -0
- tests/test_image_annotation_unit.py +551 -0
- tests/test_training_moe.py +13 -0
- training/__init__.py +42 -0
- training/datasets.py +140 -0
- training/deepspeed_zero2_0_5b.json +22 -0
- training/deepspeed_zero2_1_5b.json +22 -0
- training/deepspeed_zero3_0_5b.json +28 -0
- training/deepspeed_zero3_14b.json +28 -0
- training/deepspeed_zero3_h100_3gpu.json +20 -0
- training/deepspeed_zero3_offload.json +28 -0
- training/eval.py +92 -0
- training/finetune.py +516 -0
- training/public_datasets.py +89 -0
- training_data/react_train.jsonl +7473 -0
- utils/agent_discovery.py +311 -0
- utils/batch_processor.py +317 -0
- utils/conversation.py +78 -0
- utils/edit_distance.py +118 -0
- utils/formatter.py +33 -0
- utils/graph_reasoner.py +530 -0
- utils/rate_limiter.py +283 -0
- utils/router.py +2 -2
- utils/tool_discovery.py +307 -0
- webui/__init__.py +10 -0
- webui/app.py +229 -0
- webui/config.py +104 -0
- webui/static/css/style.css +332 -0
- webui/static/js/main.js +284 -0
- webui/templates/index.html +42 -0
- tests/test_crca_excel.py +0 -166
- tests/test_data_broker.py +0 -424
- tests/test_palantir.py +0 -349
- {crca-1.4.0.dist-info → crca-1.5.0.dist-info}/WHEEL +0 -0
- {crca-1.4.0.dist-info → crca-1.5.0.dist-info}/licenses/LICENSE +0 -0
training/finetune.py
ADDED
@@ -0,0 +1,516 @@
"""Low-compute finetuning pipeline (LoRA/QLoRA when available)."""

from __future__ import annotations

import os
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Optional

import torch

# Disable DeepSpeed op building if CUDA_HOME not set (prevents MissingCUDAException)
if "CUDA_HOME" not in os.environ and "DS_BUILD_OPS" not in os.environ:
    os.environ["DS_BUILD_OPS"] = "0"

MODEL_REGISTRY: Dict[str, Dict[str, object]] = {
    "google/switch-base-8": {"arch": "seq2seq", "moe": True},
    "google/switch-large-16": {"arch": "seq2seq", "moe": True},
}


def _resolve_model_info(base_model: str) -> Dict[str, object]:
    model = base_model.lower()
    for model_id, info in MODEL_REGISTRY.items():
        if model == model_id.lower() or model.startswith(model_id.lower()):
            return {"arch": info.get("arch", "causal"), "moe": info.get("moe", False)}
    if "switch" in model:
        return {"arch": "seq2seq", "moe": True}
    return {"arch": "causal", "moe": False}


@dataclass
class FinetuneConfig:
    base_model: str = "microsoft/phi-2"
    output_dir: str = "lrm_finetune_out"
    train_file: str = "training_data/react_train.jsonl"
    eval_file: Optional[str] = None
    num_train_epochs: int = 1
    per_device_batch_size: int = 2
    gradient_accumulation_steps: int = 1
    learning_rate: float = 2e-4
    use_lora: bool = True
    max_seq_length: int = 512
    gradient_checkpointing: bool = False
    fp16: bool = True
    bf16: bool = False
    deepspeed_config: Optional[str] = None


def full_finetune_qwen25_1_5b_config() -> FinetuneConfig:
    """
    Full finetune configuration for Qwen2.5-1.5B-Instruct.

    Aggressively optimized for CRCA reasoning:
    - Higher learning rate (5e-4) for smaller models, to steer away from boilerplate responses
    - Longer sequences (8192) to capture full reasoning chains
    - Full finetune (no LoRA) for maximum reasoning capability
    - DeepSpeed ZeRO-2 for memory efficiency
    - 20 epochs for thorough convergence

    Compatible with NVIDIA GPUs (e.g. H100 SXM).
    """
    return FinetuneConfig(
        base_model="Qwen/Qwen2.5-1.5B-Instruct",
        output_dir="lrm_qwen25_1_5b_full_finetune",
        train_file="training_data/react_train.jsonl",
        eval_file=None,
        num_train_epochs=20,  # More epochs for CRCA reasoning convergence
        per_device_batch_size=8,  # Cloud GPU optimized
        gradient_accumulation_steps=16,  # Effective batch size: 128
        learning_rate=5e-4,  # Higher LR for smaller models, aggressive for CRCA
        use_lora=False,  # Full finetune for maximum reasoning capability
        max_seq_length=8192,  # Longer sequences for reasoning chains
        gradient_checkpointing=True,  # Enable for memory efficiency
        fp16=True,
        bf16=False,
        deepspeed_config="training/deepspeed_zero2_1_5b.json",
    )


def full_finetune_qwen25_7b_config() -> FinetuneConfig:
    """
    Full finetune configuration for Qwen2.5-7B-Instruct.

    Aggressively optimized for CRCA reasoning:
    - Moderate learning rate (2e-4) for stability with reasoning tasks
    - 20 epochs for thorough convergence on CRCA tasks
    - Full finetune (no LoRA) for maximum reasoning capability
    - DeepSpeed ZeRO-3 with CPU offload for memory efficiency
    - BF16 for better numerical stability on larger models

    Compatible with NVIDIA GPUs (e.g. H100 SXM).
    """
    return FinetuneConfig(
        base_model="Qwen/Qwen2.5-7B-Instruct",
        output_dir="lrm_qwen25_7b_full_finetune",
        train_file="training_data/react_train.jsonl",
        eval_file=None,
        num_train_epochs=20,  # Increased from 1 for CRCA reasoning convergence
        per_device_batch_size=4,  # Cloud GPU optimized
        gradient_accumulation_steps=32,  # Effective batch size: 128
        learning_rate=2e-4,  # Optimized for CRCA reasoning, higher than default
        use_lora=False,  # Full finetune for maximum reasoning capability
        max_seq_length=4096,  # Full context for reasoning chains
        gradient_checkpointing=True,  # Enable for memory efficiency
        fp16=False,
        bf16=True,  # Better numerical stability for larger models
        deepspeed_config="training/deepspeed_zero3_offload.json",
    )


def full_finetune_qwen25_14b_config() -> FinetuneConfig:
    """
    Full finetune configuration for Qwen2.5-14B-Instruct.

    Aggressively optimized for CRCA reasoning:
    - Lower learning rate (1e-4) for stability on large models
    - 20 epochs for thorough convergence on CRCA tasks
    - Full finetune (no LoRA) for maximum reasoning capability
    - DeepSpeed ZeRO-3 with CPU offload for memory efficiency
    - BF16 required for numerical stability
    - Longer gradient accumulation for effective batch size

    Compatible with NVIDIA GPUs (e.g. H100 SXM).
    """
    return FinetuneConfig(
        base_model="Qwen/Qwen2.5-14B-Instruct",
        output_dir="lrm_qwen25_14b_full_finetune",
        train_file="training_data/react_train.jsonl",
        eval_file=None,
        num_train_epochs=20,  # Thorough convergence for CRCA reasoning
        per_device_batch_size=2,  # Cloud GPU optimized (memory constrained)
        gradient_accumulation_steps=64,  # Effective batch size: 128
        learning_rate=1e-4,  # Lower LR for stability, still aggressive for CRCA
        use_lora=False,  # Full finetune for maximum reasoning capability
        max_seq_length=2048,  # Memory constraints on 14B model
        gradient_checkpointing=True,  # Critical for memory efficiency
        fp16=False,
        bf16=True,  # Required for numerical stability on large models
        deepspeed_config="training/deepspeed_zero3_14b.json",
    )


def full_finetune_switch_base_8_config() -> FinetuneConfig:
    """
    Full finetune configuration for Switch MoE base (Seq2Seq).

    Optimized for Switch MoE (encoder-decoder):
    - BF16 for H100 stability
    - ZeRO-3 without CPU offload (H100-class GPUs)
    - Moderate batch sizes for Seq2Seq memory footprint
    """
    return FinetuneConfig(
        base_model="google/switch-base-8",
        output_dir="lrm_switch_base_8_full_finetune",
        train_file="training_data/react_train.jsonl",
        eval_file=None,
        num_train_epochs=10,
        per_device_batch_size=4,
        gradient_accumulation_steps=16,
        learning_rate=2e-4,
        use_lora=False,
        max_seq_length=1024,
        gradient_checkpointing=True,
        fp16=False,
        bf16=True,
        deepspeed_config="training/deepspeed_zero3_h100_3gpu.json",
    )


def full_finetune_switch_large_16_config() -> FinetuneConfig:
    """
    Full finetune configuration for Switch MoE large (Seq2Seq).

    - BF16 for numerical stability
    - ZeRO-3 without CPU offload (H100-class GPUs)
    - Conservative batch sizes to keep memory stable
    """
    return FinetuneConfig(
        base_model="google/switch-large-16",
        output_dir="lrm_switch_large_16_full_finetune",
        train_file="training_data/react_train.jsonl",
        eval_file=None,
        num_train_epochs=10,
        per_device_batch_size=2,
        gradient_accumulation_steps=32,
        learning_rate=1e-4,
        use_lora=False,
        max_seq_length=1024,
        gradient_checkpointing=True,
        fp16=False,
        bf16=True,
        deepspeed_config="training/deepspeed_zero3_h100_3gpu.json",
    )


def full_finetune_qwen25_0_5b_config_cloud() -> FinetuneConfig:
    """
    Cloud GPU optimized configuration for Qwen2.5-0.5B-Instruct.

    For GPUs with 16GB+ VRAM (RTX 3090, A4000, A100, etc.):
    - Much larger batch sizes
    - Longer sequences
    - Full finetune (no LoRA needed)
    """
    return FinetuneConfig(
        base_model="Qwen/Qwen2.5-0.5B-Instruct",
        output_dir="lrm_qwen25_0_5b_full_finetune",
        train_file="training_data/react_train.jsonl",
        eval_file=None,
        num_train_epochs=20,
        per_device_batch_size=16,  # Cloud GPUs can handle this
        gradient_accumulation_steps=8,  # Adjusted for effective batch size
        learning_rate=4e-4,
        use_lora=False,  # Full finetune on cloud GPU
        max_seq_length=4096,  # Full context length on cloud
        gradient_checkpointing=True,
        fp16=True,
        bf16=False,
        deepspeed_config=None,  # Not needed on cloud GPUs
    )


def full_finetune_qwen25_0_5b_config() -> FinetuneConfig:
    """
    Low-VRAM configuration for Qwen2.5-0.5B-Instruct.

    Sized for ~4GB GPUs:
    - Per-device batch size of 1 with large gradient accumulation (effective batch size: 128)
    - Higher learning rate (smaller models can handle higher LRs)
    - LoRA with quantization instead of CPU offload - trains only ~1% of parameters
    - No DeepSpeed needed - the quantized model stays on the GPU
    """
    return FinetuneConfig(
        base_model="Qwen/Qwen2.5-0.5B-Instruct",
        output_dir="lrm_qwen25_0_5b_full_finetune",
        train_file="training_data/react_train.jsonl",
        eval_file=None,
        num_train_epochs=20,
        per_device_batch_size=1,  # Must be 1 for 4GB GPU
        gradient_accumulation_steps=128,  # Large accumulation to maintain effective batch size
        learning_rate=4e-4,  # Smaller models can handle higher learning rates
        use_lora=True,  # Use LoRA to avoid CPU offload - trains only ~1% of parameters
        max_seq_length=512,  # Must be 512 or less for 4GB GPU
        gradient_checkpointing=False,  # Not needed with LoRA + 8-bit
        fp16=True,
        bf16=False,
        deepspeed_config=None,  # No DeepSpeed needed with LoRA - stays on GPU
    )


def run_finetune(cfg: FinetuneConfig) -> None:
    """
    Run finetuning on NVIDIA GPUs (e.g. H100 SXM).

    Uses CUDA with NCCL for distributed training. Supports 4-bit/8-bit quantization
    for LoRA and full finetune with BF16/FP16. DeepSpeed ZeRO-2/ZeRO-3 for multi-GPU.
    """
    print("CUDA (NVIDIA GPU) detected - using CUDA settings")

    # Configure environment for single GPU DeepSpeed (if using DeepSpeed)
    if cfg.deepspeed_config:
        if "RANK" not in os.environ:
            os.environ["RANK"] = "0"
        if "LOCAL_RANK" not in os.environ:
            os.environ["LOCAL_RANK"] = "0"
        if "WORLD_SIZE" not in os.environ:
            os.environ["WORLD_SIZE"] = "1"
        if "MASTER_ADDR" not in os.environ:
            os.environ["MASTER_ADDR"] = "localhost"
        if "MASTER_PORT" not in os.environ:
            os.environ["MASTER_PORT"] = "29500"

    try:
        from datasets import load_dataset  # type: ignore
        from transformers import (
            AutoModelForCausalLM,
            AutoModelForSeq2SeqLM,
            AutoTokenizer,
            DataCollatorForSeq2Seq,
            Trainer,
            TrainingArguments,
        )  # type: ignore
    except Exception as exc:
        raise RuntimeError(
            "Missing training dependencies. Install: transformers, datasets, accelerate, peft"
        ) from exc

    model_info = _resolve_model_info(cfg.base_model)
    is_seq2seq = model_info.get("arch") == "seq2seq"
    if model_info.get("moe"):
        print("MoE model detected - using Seq2Seq pipeline")

    # Load tokenizer with error handling
    try:
        tokenizer = AutoTokenizer.from_pretrained(cfg.base_model, trust_remote_code=True)
    except Exception as exc:
        raise RuntimeError(f"Failed to load tokenizer from {cfg.base_model}: {exc}") from exc

    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    if tokenizer.pad_token is None:
        raise ValueError(f"Tokenizer from {cfg.base_model} has no pad_token or eos_token")

    model_cls = AutoModelForSeq2SeqLM if is_seq2seq else AutoModelForCausalLM

    # Load model (CUDA: 4-bit/8-bit for LoRA, full precision for full finetune)
    if cfg.use_lora:
        try:
            from transformers import BitsAndBytesConfig

            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16,
            )
            model = model_cls.from_pretrained(
                cfg.base_model,
                quantization_config=quantization_config,
                device_map="auto",
            )
            print("Using 4-bit quantization (CUDA)")
        except Exception:
            try:
                from transformers import BitsAndBytesConfig

                quantization_config = BitsAndBytesConfig(
                    load_in_8bit=True,
                    bnb_8bit_compute_dtype=torch.float16,
                )
                model = model_cls.from_pretrained(
                    cfg.base_model,
                    quantization_config=quantization_config,
                    device_map="auto",
                )
                print("Using 8-bit quantization (4-bit not available)")
            except Exception:
                model = model_cls.from_pretrained(
                    cfg.base_model,
                    torch_dtype=torch.bfloat16 if cfg.bf16 else torch.float16,
                    low_cpu_mem_usage=True,
                )
                print("Using full precision (quantization not available)")
    else:
        # Full finetune: use BF16/FP16 based on config
        model = model_cls.from_pretrained(
            cfg.base_model,
            torch_dtype=torch.bfloat16 if cfg.bf16 else torch.float16,
            low_cpu_mem_usage=True,
        )
        precision_str = "BF16" if cfg.bf16 else "FP16"
        print(f"Using full finetune with {precision_str} precision")

    if cfg.use_lora:
        try:
            from peft import LoraConfig, get_peft_model  # type: ignore
        except Exception as exc:
            raise RuntimeError("LoRA requested but peft not installed. Install peft.") from exc

        lora = LoraConfig(
            r=8,
            lora_alpha=16,
            lora_dropout=0.05,
            bias="none",
            task_type="SEQ_2_SEQ_LM" if is_seq2seq else "CAUSAL_LM",  # match the model architecture
        )
        model = get_peft_model(model, lora)
        # LoRA doesn't need gradient checkpointing - it's already memory efficient
    elif cfg.gradient_checkpointing:
        model.gradient_checkpointing_enable()

    if not Path(cfg.train_file).exists():
        raise FileNotFoundError(f"Training file not found: {cfg.train_file}")
    if cfg.eval_file and not Path(cfg.eval_file).exists():
        raise FileNotFoundError(f"Eval file not found: {cfg.eval_file}")

    data_files = {"train": cfg.train_file}
    if cfg.eval_file:
        data_files["validation"] = cfg.eval_file

    try:
        dataset = load_dataset("json", data_files=data_files)
    except Exception as exc:
        raise RuntimeError(f"Failed to load dataset from {data_files}: {exc}") from exc

    if "train" not in dataset:
        raise ValueError(f"Dataset must contain 'train' split, got: {list(dataset.keys())}")
    if len(dataset["train"]) == 0:
        raise ValueError("Training dataset is empty")

    def _tokenize(examples):
        if is_seq2seq:
            inputs = tokenizer(
                examples["prompt"],
                truncation=True,
                padding="max_length",
                max_length=cfg.max_seq_length,
                return_tensors=None,
            )
            targets = tokenizer(
                text_target=examples["response"],
                truncation=True,
                padding="max_length",
                max_length=cfg.max_seq_length,
                return_tensors=None,
            )
            pad_token_id = tokenizer.pad_token_id
            labels = [
                [token_id if token_id != pad_token_id else -100 for token_id in seq]
                for seq in targets["input_ids"]
            ]
            inputs["labels"] = labels
            return inputs

        texts = [p + "\n" + r for p, r in zip(examples["prompt"], examples["response"])]
        tokenized = tokenizer(
            texts,
            truncation=True,
            padding="max_length",
            max_length=cfg.max_seq_length,
            return_tensors=None,  # Return lists, not tensors
        )
        # For causal LM, labels are the same as input_ids
        # Set padding tokens to -100 so they're ignored in loss calculation
        labels = []
        pad_token_id = tokenizer.pad_token_id
        for input_ids in tokenized["input_ids"]:
            label = [token_id if token_id != pad_token_id else -100 for token_id in input_ids]
            labels.append(label)
        tokenized["labels"] = labels
        return tokenized

    # Get column names before tokenization (handle both train and validation)
    original_columns = dataset["train"].column_names

    tokenized = dataset.map(
        _tokenize,
        batched=True,
        remove_columns=original_columns,
    )

    # Validate configuration before training
    if cfg.per_device_batch_size < 1:
        raise ValueError(f"per_device_batch_size must be >= 1, got {cfg.per_device_batch_size}")
    if cfg.gradient_accumulation_steps < 1:
        raise ValueError(f"gradient_accumulation_steps must be >= 1, got {cfg.gradient_accumulation_steps}")
    if cfg.learning_rate <= 0:
        raise ValueError(f"learning_rate must be > 0, got {cfg.learning_rate}")
    if cfg.max_seq_length < 1:
        raise ValueError(f"max_seq_length must be >= 1, got {cfg.max_seq_length}")
    if cfg.fp16 and cfg.bf16:
        raise ValueError("Cannot use both fp16 and bf16 simultaneously")

    # Resolve DeepSpeed config path if provided
    deepspeed_config_path = None
    if cfg.deepspeed_config:
        deepspeed_config_path = str(Path(cfg.deepspeed_config).resolve())
        if not Path(deepspeed_config_path).exists():
            raise FileNotFoundError(f"DeepSpeed config file not found: {deepspeed_config_path}")

    args = TrainingArguments(
        output_dir=cfg.output_dir,
        num_train_epochs=cfg.num_train_epochs,
        per_device_train_batch_size=cfg.per_device_batch_size,
        per_device_eval_batch_size=cfg.per_device_batch_size,
        gradient_accumulation_steps=cfg.gradient_accumulation_steps,
        learning_rate=cfg.learning_rate,
        fp16=cfg.fp16,
        bf16=cfg.bf16,
        gradient_checkpointing=cfg.gradient_checkpointing if not cfg.use_lora else False,  # LoRA doesn't need it
        deepspeed=deepspeed_config_path if deepspeed_config_path else None,
        logging_steps=50,
        save_steps=200,
        eval_strategy="no" if cfg.eval_file is None else "steps",
        eval_steps=200 if cfg.eval_file else None,  # Evaluate every 200 steps if eval_file provided
        save_total_limit=2,
        remove_unused_columns=False,
        dataloader_pin_memory=False,  # Save memory
        dataloader_num_workers=0,  # Reduce memory overhead
        optim="adamw_torch",  # Use standard AdamW (more memory efficient than fused variants)
        max_grad_norm=1.0,  # Gradient clipping
        warmup_steps=100,  # Add warmup for better convergence
        lr_scheduler_type="cosine",  # Cosine learning rate schedule for better convergence
    )

    train_dataset = tokenized["train"]
    eval_dataset = tokenized.get("validation") if cfg.eval_file else None

    if len(train_dataset) == 0:
        raise ValueError("Tokenized training dataset is empty")

    data_collator = None
    if is_seq2seq:
        data_collator = DataCollatorForSeq2Seq(
            tokenizer=tokenizer,
            model=model,
            label_pad_token_id=-100,
        )

    # Data collator is optional when using max_length padding, but helps ensure consistency
    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=data_collator,
    )

    print(f"Starting training with {len(train_dataset)} examples")
    if eval_dataset:
        print(f"Evaluation dataset has {len(eval_dataset)} examples")
    print(f"Effective batch size: {cfg.per_device_batch_size * cfg.gradient_accumulation_steps}")
    print(f"Total training steps: {len(train_dataset) // (cfg.per_device_batch_size * cfg.gradient_accumulation_steps) * cfg.num_train_epochs}")

    try:
        trainer.train()
    except Exception as exc:
        raise RuntimeError(f"Training failed: {exc}") from exc

    try:
        trainer.save_model(cfg.output_dir)
        print(f"Model saved to {cfg.output_dir}")
    except Exception as exc:
        raise RuntimeError(f"Failed to save model to {cfg.output_dir}: {exc}") from exc
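
A minimal usage sketch follows (not part of the diff). It assumes transformers, datasets, accelerate, and peft are installed and that the bundled training_data/react_train.jsonl is on disk; any of the preset configs above can be swapped in.

# Hypothetical driver, using only names defined in training/finetune.py above.
from training.finetune import full_finetune_qwen25_0_5b_config, run_finetune

cfg = full_finetune_qwen25_0_5b_config()  # LoRA preset sized for ~4GB GPUs
cfg.num_train_epochs = 1                  # shorten the 20-epoch default for a smoke test
run_finetune(cfg)                         # trains, then saves to cfg.output_dir
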
training/public_datasets.py
ADDED

@@ -0,0 +1,89 @@
"""Public dataset ingestion for hybrid LRM training."""

from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Dict, Iterable, List, Optional

from training.datasets import ReActExample, normalize_text


@dataclass
class PublicDatasetConfig:
    name: str
    split: str
    prompt_field: str
    response_field: str
    config_name: Optional[str] = None
    prompt_template: Optional[str] = None
    response_template: Optional[str] = None
    system_field: Optional[str] = None
    max_samples: Optional[int] = None
    license_tag: Optional[str] = None
    source_tag: Optional[str] = None


def default_public_configs() -> List[PublicDatasetConfig]:
    """Conservative defaults with known field names."""
    return [
        PublicDatasetConfig(
            name="openai/gsm8k",
            config_name="main",
            split="train",
            prompt_field="question",
            response_field="answer",
            prompt_template="Question: {question}\nAnswer:",
            response_template="{answer}",
            license_tag="unknown",
            source_tag="gsm8k",
        )
    ]


def _format_with_template(template: Optional[str], row: Dict[str, Any], field: str) -> str:
    if template:
        return template.format(**row)
    value = row.get(field, "")
    return str(value) if value is not None else ""


def load_public_examples(
    configs: Iterable[PublicDatasetConfig],
    *,
    seed: int = 7,
) -> List[ReActExample]:
    try:
        from datasets import load_dataset  # type: ignore
    except Exception as exc:
        raise RuntimeError(f"datasets library is required to load public datasets: {exc}") from exc

    examples: List[ReActExample] = []
    for cfg in configs:
        if cfg.config_name:
            dataset = load_dataset(cfg.name, cfg.config_name, split=cfg.split)
        else:
            dataset = load_dataset(cfg.name, split=cfg.split)
        rows = list(dataset)
        if cfg.max_samples is not None:
            rows = rows[: cfg.max_samples]
        for row in rows:
            prompt = _format_with_template(cfg.prompt_template, row, cfg.prompt_field)
            response = _format_with_template(cfg.response_template, row, cfg.response_field)
            if cfg.system_field and row.get(cfg.system_field):
                system = str(row[cfg.system_field])
                prompt = f"System: {system}\n{prompt}"
            prompt = normalize_text(prompt)
            response = normalize_text(response)
            if not prompt or not response:
                continue
            tags = {
                "type": "public_reasoning",
                "dataset": cfg.name,
            }
            if cfg.license_tag:
                tags["license"] = cfg.license_tag
            if cfg.source_tag:
                tags["source"] = cfg.source_tag
            examples.append(ReActExample(prompt=prompt, response=response, tags=tags, refusal=False))
    return examples
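
A similar usage sketch for the ingestion helpers (again not part of the diff; it assumes the datasets library is installed and that the sibling training.datasets module provides ReActExample and normalize_text, as imported above):

# Hypothetical: build ReAct examples from the default GSM8K config, capped for a quick check.
from training.public_datasets import default_public_configs, load_public_examples

configs = default_public_configs()
configs[0].max_samples = 100    # limit download and processing for the smoke test
examples = load_public_examples(configs)
print(len(examples), examples[0].tags)  # tags carry dataset, license, and source metadata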