marin-core 0.99__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- marin/__init__.py +2 -0
- marin/cluster/__init__.py +10 -0
- marin/cluster/gcp.py +158 -0
- marin/core/conversation.py +25 -0
- marin/core/data.py +54 -0
- marin/datakit/__init__.py +24 -0
- marin/datakit/canonical/__init__.py +2 -0
- marin/datakit/canonical/fineweb_edu.py +39 -0
- marin/datakit/download/__init__.py +2 -0
- marin/datakit/download/ar5iv.py +156 -0
- marin/datakit/download/bio_chem/__init__.py +22 -0
- marin/datakit/download/bio_chem/_runtime.py +238 -0
- marin/datakit/download/bio_chem/chembl.py +53 -0
- marin/datakit/download/bio_chem/moleculenet.py +59 -0
- marin/datakit/download/bio_chem/pubchem.py +45 -0
- marin/datakit/download/bio_chem/rcsb_pdb.py +116 -0
- marin/datakit/download/bio_chem/refseq.py +49 -0
- marin/datakit/download/bio_chem/rnacentral.py +36 -0
- marin/datakit/download/bio_chem/uniprot.py +54 -0
- marin/datakit/download/biodiversity.py +107 -0
- marin/datakit/download/coderforge.py +127 -0
- marin/datakit/download/common_corpus.py +86 -0
- marin/datakit/download/common_pile.py +126 -0
- marin/datakit/download/davinci_dev.py +265 -0
- marin/datakit/download/diagnostic_logs.py +1158 -0
- marin/datakit/download/dolma.py +41 -0
- marin/datakit/download/dolmino.py +32 -0
- marin/datakit/download/finepdfs.py +62 -0
- marin/datakit/download/formal_methods_evals.py +512 -0
- marin/datakit/download/game_music_evals.py +309 -0
- marin/datakit/download/gh_archive.py +416 -0
- marin/datakit/download/gpt_oss_rollouts.py +100 -0
- marin/datakit/download/hf_simple_util.py +87 -0
- marin/datakit/download/hplt.py +273 -0
- marin/datakit/download/huggingface.py +452 -0
- marin/datakit/download/institutional_books.py +89 -0
- marin/datakit/download/massive.py +803 -0
- marin/datakit/download/molmo2_cap.py +89 -0
- marin/datakit/download/nemotron_terminal.py +86 -0
- marin/datakit/download/nemotron_v1.py +163 -0
- marin/datakit/download/nemotron_v2.py +213 -0
- marin/datakit/download/npm_registry_metadata.py +341 -0
- marin/datakit/download/nsf_awards.py +176 -0
- marin/datakit/download/principia.py +80 -0
- marin/datakit/download/rollout_transforms.py +33 -0
- marin/datakit/download/stackexchange/README.md +20 -0
- marin/datakit/download/stackexchange/stackexchange-urls.tsv +183 -0
- marin/datakit/download/starcoder2_extras.py +59 -0
- marin/datakit/download/superior_reasoning.py +92 -0
- marin/datakit/download/svgfind.py +101 -0
- marin/datakit/download/swe_rebench_openhands.py +95 -0
- marin/datakit/download/synthetic1.py +103 -0
- marin/datakit/download/uncheatable_eval.py +430 -0
- marin/datakit/download/uwf_zeek.py +254 -0
- marin/datakit/download/wikipedia.py +122 -0
- marin/datakit/ingestion_manifest.py +209 -0
- marin/datakit/normalize.py +576 -0
- marin/datakit/sources.py +345 -0
- marin/evaluation/__init__.py +2 -0
- marin/evaluation/evaluation_config.py +119 -0
- marin/evaluation/evaluators/__init__.py +2 -0
- marin/evaluation/evaluators/evalchemy_evaluator.py +1055 -0
- marin/evaluation/evaluators/evaluator.py +51 -0
- marin/evaluation/evaluators/harbor_evaluator.py +676 -0
- marin/evaluation/evaluators/levanter_lm_eval_evaluator.py +120 -0
- marin/evaluation/evaluators/lm_evaluation_harness_evaluator.py +209 -0
- marin/evaluation/evaluators/simple_evaluator.py +171 -0
- marin/evaluation/lm_eval.py +107 -0
- marin/evaluation/log_probs.py +153 -0
- marin/evaluation/perplexity_gap.py +341 -0
- marin/evaluation/run.py +157 -0
- marin/evaluation/save_logprobs.py +262 -0
- marin/evaluation/utils.py +73 -0
- marin/evaluation/visualize.py +74 -0
- marin/execution/__init__.py +31 -0
- marin/execution/__init__.pyi +31 -0
- marin/execution/artifact.py +73 -0
- marin/execution/disk_cache.py +134 -0
- marin/execution/executor.py +1892 -0
- marin/execution/executor_step_status.py +283 -0
- marin/execution/remote.py +137 -0
- marin/execution/step_runner.py +417 -0
- marin/execution/step_spec.py +144 -0
- marin/execution/sweep.py +90 -0
- marin/export/__init__.py +9 -0
- marin/export/hf_upload.py +274 -0
- marin/export/levanter_checkpoint.py +187 -0
- marin/inference/__init__.py +4 -0
- marin/inference/types.py +46 -0
- marin/inference/vllm_server.py +372 -0
- marin/inference/vllm_smoke_test.py +221 -0
- marin/infra/__init__.py +8 -0
- marin/markdown/__init__.py +4 -0
- marin/markdown/markdown.py +998 -0
- marin/mcp/__init__.py +4 -0
- marin/mcp/babysitter.py +848 -0
- marin/processing/__init__.py +2 -0
- marin/processing/classification/README.md +27 -0
- marin/processing/classification/__init__.py +2 -0
- marin/processing/classification/config/quickstart_decontaminate.yaml +8 -0
- marin/processing/classification/config/quickstart_dedupe.yaml +4 -0
- marin/processing/classification/consolidate.py +198 -0
- marin/processing/classification/decon.py +369 -0
- marin/processing/classification/deduplication/__init__.py +2 -0
- marin/processing/classification/deduplication/connected_components.py +327 -0
- marin/processing/classification/deduplication/dedup_commons.py +205 -0
- marin/processing/classification/deduplication/exact.py +260 -0
- marin/processing/classification/deduplication/fuzzy_dups.py +386 -0
- marin/processing/classification/deduplication/fuzzy_minhash.py +233 -0
- marin/processing/tokenize/__init__.py +17 -0
- marin/processing/tokenize/data_configs.py +399 -0
- marin/processing/tokenize/download_pretokenized.py +153 -0
- marin/processing/tokenize/tokenize.py +545 -0
- marin/profiling/__init__.py +45 -0
- marin/profiling/cli.py +469 -0
- marin/profiling/compare_bundle.py +105 -0
- marin/profiling/ingest.py +1849 -0
- marin/profiling/publish.py +128 -0
- marin/profiling/query.py +534 -0
- marin/profiling/report.py +241 -0
- marin/profiling/schema.py +545 -0
- marin/profiling/semantics.py +147 -0
- marin/profiling/tracking.py +215 -0
- marin/py.typed +0 -0
- marin/rl/curriculum.py +608 -0
- marin/rl/environments/__init__.py +6 -0
- marin/rl/environments/base.py +85 -0
- marin/rl/environments/inference_ctx/__init__.py +25 -0
- marin/rl/environments/inference_ctx/async_vllm.py +52 -0
- marin/rl/environments/inference_ctx/base.py +134 -0
- marin/rl/environments/inference_ctx/inflight/async_bridge.py +102 -0
- marin/rl/environments/inference_ctx/inflight/worker.py +280 -0
- marin/rl/environments/inference_ctx/levanter.py +146 -0
- marin/rl/environments/inference_ctx/render.py +373 -0
- marin/rl/environments/inference_ctx/staging.py +75 -0
- marin/rl/environments/inference_ctx/vllm.py +394 -0
- marin/rl/environments/inference_ctx/vllm_utils.py +146 -0
- marin/rl/environments/math_env.py +306 -0
- marin/rl/environments/mock_env.py +362 -0
- marin/rl/environments/prime_intellect_env.py +196 -0
- marin/rl/environments/process_vllm_results.py +180 -0
- marin/rl/environments/tinker_environments/math_env.py +140 -0
- marin/rl/environments/tinker_environments/math_grading.py +598 -0
- marin/rl/math_utils.py +57 -0
- marin/rl/metrics.py +24 -0
- marin/rl/model_utils.py +122 -0
- marin/rl/orchestration.py +339 -0
- marin/rl/placement.py +72 -0
- marin/rl/replay_buffer.py +407 -0
- marin/rl/rl_experiment_utils.py +367 -0
- marin/rl/rl_job.py +320 -0
- marin/rl/rl_losses.py +473 -0
- marin/rl/rollout_storage.py +450 -0
- marin/rl/rollout_worker.py +1106 -0
- marin/rl/run_state.py +111 -0
- marin/rl/runtime.py +35 -0
- marin/rl/scripts/export_env_prompts.py +142 -0
- marin/rl/scripts/replay_completions.py +279 -0
- marin/rl/scripts/view_rollout.py +121 -0
- marin/rl/scripts/visualize_curriculum.ipynb +355 -0
- marin/rl/train_batch.py +168 -0
- marin/rl/train_worker.py +651 -0
- marin/rl/types.py +123 -0
- marin/rl/weight_transfer/__init__.py +99 -0
- marin/rl/weight_transfer/arrow_flight.py +754 -0
- marin/rl/weight_transfer/base.py +151 -0
- marin/rl/weight_transfer/checkpoint.py +192 -0
- marin/rl/weight_utils.py +135 -0
- marin/run/__init__.py +2 -0
- marin/run/slurm_run.py +309 -0
- marin/scaling_laws/__init__.py +60 -0
- marin/scaling_laws/eval_metrics_reader.py +120 -0
- marin/scaling_laws/isoflop_analysis.py +423 -0
- marin/scaling_laws/scaling_plots.py +317 -0
- marin/scaling_laws/tpu_utils.py +79 -0
- marin/schemas/web/convert.py +72 -0
- marin/schemas/web/selectors.py +42 -0
- marin/tokenize/slice_cache.py +231 -0
- marin/training/__init__.py +19 -0
- marin/training/run_environment.py +45 -0
- marin/training/training.py +446 -0
- marin/transform/ar5iv/transform.py +280 -0
- marin/transform/ar5iv/transform_ar5iv.py +175 -0
- marin/transform/bio_chem/__init__.py +21 -0
- marin/transform/bio_chem/splitters.py +276 -0
- marin/transform/common_pile/filter_by_extension.py +117 -0
- marin/transform/conversation/adapters.py +159 -0
- marin/transform/conversation/conversation_to_dolma.py +54 -0
- marin/transform/conversation/preference_data_adapters.py +73 -0
- marin/transform/conversation/transform_conversation.py +416 -0
- marin/transform/conversation/transform_preference_data.py +296 -0
- marin/transform/dolmino/filter_dolmino.py +61 -0
- marin/transform/evaluation/eval_to_dolma.py +48 -0
- marin/transform/evaluation/raw_lm_eval.py +372 -0
- marin/transform/huggingface/dataset_to_eval.py +590 -0
- marin/transform/huggingface/raw_text.py +295 -0
- marin/transform/lingoly/to_dolma.py +87 -0
- marin/transform/medical/lavita_to_dolma.py +147 -0
- marin/transform/security_artifacts/__init__.py +31 -0
- marin/transform/security_artifacts/renderers.py +285 -0
- marin/transform/security_artifacts/zeek_to_dolma.py +213 -0
- marin/transform/simple_html_to_md/process.py +84 -0
- marin/transform/stackexchange/filter_stackexchange.py +78 -0
- marin/transform/stackexchange/transform_stackexchange.py +151 -0
- marin/transform/structured_text/__init__.py +16 -0
- marin/transform/structured_text/table_records.py +487 -0
- marin/transform/structured_text/web_data_commons.py +374 -0
- marin/transform/wikipedia/transform_wikipedia.py +298 -0
- marin/utilities/__init__.py +2 -0
- marin/utilities/dataclass_utils.py +13 -0
- marin/utilities/executor_utils.py +83 -0
- marin/utilities/json_encoder.py +33 -0
- marin/utilities/upload_gcs_to_hf.py +364 -0
- marin/utilities/validation_utils.py +73 -0
- marin/utilities/wandb_utils.py +55 -0
- marin/utils.py +220 -0
- marin/validate/validate.py +163 -0
- marin/web/__init__.py +4 -0
- marin/web/convert.py +114 -0
- marin_core-0.99.dist-info/METADATA +293 -0
- marin_core-0.99.dist-info/RECORD +223 -0
- marin_core-0.99.dist-info/WHEEL +4 -0
- marin_core-0.99.dist-info/entry_points.txt +2 -0
marin/__init__.py
ADDED
marin/cluster/gcp.py
ADDED
|
@@ -0,0 +1,158 @@
|
|
|
1
|
+
# Copyright The Marin Authors
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
"""GCP utilities for cluster management.
|
|
5
|
+
|
|
6
|
+
This provides functions to get access to the current GCP configuration, list and
|
|
7
|
+
connect to TPUs, and find TPUs by IP address.
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
import json
|
|
11
|
+
import logging
|
|
12
|
+
import subprocess
|
|
13
|
+
from typing import Any
|
|
14
|
+
|
|
15
|
+
logger = logging.getLogger(__name__)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def run_gcloud_command(cmd: list[str], **kwargs) -> subprocess.CompletedProcess:
    """Execute a gcloud invocation and hand back the finished process.

    Args:
        cmd: Complete argument vector, including the leading ``gcloud``.
        **kwargs: Extra keyword arguments forwarded to ``subprocess.run``.

    Returns:
        The ``CompletedProcess``; stdout and stderr are captured as text.

    Raises:
        RuntimeError: If the command exits with a non-zero status. The
            message carries the command line and the captured stderr, and
            the original ``CalledProcessError`` is chained as the cause.
    """
    try:
        logger.info(f"Running {' '.join(cmd)}")
        completed = subprocess.run(cmd, check=True, capture_output=True, text=True, **kwargs)
    except subprocess.CalledProcessError as e:
        failure = f"gcloud command failed: {' '.join(cmd)}\nError: {e.stderr}"
        raise RuntimeError(failure) from e
    return completed
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def get_project_id() -> str | None:
    """Return the project set in the active gcloud configuration.

    Returns:
        The project ID with surrounding whitespace removed, or ``None`` when
        gcloud prints nothing or the lookup command fails.
    """
    try:
        completed = run_gcloud_command(["gcloud", "config", "get-value", "project"])
    except RuntimeError:
        # A failed lookup is reported as "no project", not as an error.
        return None
    project = completed.stdout.strip()
    return project if project else None
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def list_tpu_nodes(project: str, zone: str, filter_expr: str = "") -> list[dict[str, Any]]:
    """Return the TPU VM nodes that gcloud lists for a project and zone.

    Args:
        project: GCP project passed as ``--project``.
        zone: Zone passed as ``--zone``.
        filter_expr: Optional expression passed as ``--filter``; the flag is
            omitted when the string is empty.

    Returns:
        The parsed JSON output of ``gcloud compute tpus tpu-vm list``.

    Raises:
        RuntimeError: If the gcloud command fails.
    """
    cmd = ["gcloud", "compute", "tpus", "tpu-vm", "list"]
    cmd += [f"--project={project}", f"--zone={zone}", "--format=json"]
    if filter_expr:
        cmd += [f"--filter={filter_expr}"]

    return json.loads(run_gcloud_command(cmd).stdout)
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def find_tpu_by_ip(target_ip: str, project: str, zone: str = "-") -> tuple[str, str, int] | None:
    """Locate the TPU worker whose network endpoint has the given IP.

    Args:
        target_ip: IP address compared against each endpoint's ``ipAddress``.
        project: GCP project to list TPU nodes from.
        zone: Zone handed to the listing call; the default ``"-"`` is the
            all-zones wildcard.

    Returns:
        ``(tpu_name, zone, worker_index)`` for the first matching endpoint,
        where ``worker_index`` is the endpoint's position in the node's
        ``networkEndpoints`` list, or ``None`` if no endpoint matches.
    """
    for tpu in list_tpu_nodes(project, zone):
        endpoints = tpu.get("networkEndpoints", [])
        for worker_index, endpoint in enumerate(endpoints):
            if endpoint.get("ipAddress") != target_ip:
                continue

            resource_name = tpu["name"]
            segments = resource_name.split("/")
            if len(segments) < 6:
                # Name is not a six-segment resource path: return it untouched
                # together with the zone the caller asked for.
                return resource_name, zone, worker_index
            # Segment 5 follows ".../nodes/", segment 3 follows ".../locations/".
            return segments[5], segments[3], worker_index

    return None
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def find_vm_by_ip(target_ip: str, project: str) -> tuple[str, str] | None:
    """Locate the GCE instance whose first network interface has the given IP.

    The listing command carries no zone restriction; it filters on
    ``networkInterfaces[0].networkIP`` only.

    Args:
        target_ip: Internal IP address to match.
        project: GCP project to search.

    Returns:
        ``(instance_name, zone)`` with the zone reduced to its short name,
        or ``None`` when nothing matches.

    Raises:
        RuntimeError: If more than one instance matches, or if the gcloud
            command fails.
    """
    listing = run_gcloud_command(
        [
            "gcloud",
            "compute",
            "instances",
            "list",
            f"--project={project}",
            f"--filter=networkInterfaces[0].networkIP={target_ip}",
            "--format=json(name,zone)",
        ]
    )
    matches = json.loads(listing.stdout)
    if not matches:
        return None

    if len(matches) > 1:
        details = ", ".join(f"{i['name']} ({i['zone'].split('/')[-1]})" for i in matches)
        raise RuntimeError(f"Multiple VMs found with IP {target_ip}: {details}")

    match = matches[0]
    instance_name = match["name"]
    # The zone arrives as a full URL (.../zones/us-central1-a); keep the tail.
    short_zone = match["zone"].split("/")[-1]
    return instance_name, short_zone
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
def ssh_to_vm(instance_name: str, zone: str, project: str, extra_args: list[str] | None = None) -> None:
    """Open an SSH session to a GCE VM through ``gcloud compute ssh``.

    Args:
        instance_name: Name of the instance to connect to.
        zone: GCP zone of the instance.
        project: GCP project of the instance.
        extra_args: Optional arguments appended after a ``--`` separator.

    Raises:
        subprocess.CalledProcessError: If gcloud exits with a non-zero status.
    """
    cmd = ["gcloud", "compute", "ssh", instance_name, f"--zone={zone}", f"--project={project}"]
    if extra_args:
        cmd += ["--", *extra_args]

    # Output is not captured, so the session uses the caller's terminal.
    subprocess.run(cmd, check=True)
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
def ssh_to_tpu(tpu_name: str, zone: str, project: str, extra_args: list[str] | None = None, worker_id: int = 0) -> None:
    """Open an SSH session to a TPU VM through ``gcloud compute tpus tpu-vm ssh``.

    Args:
        tpu_name: Name of the TPU.
        zone: GCP zone of the TPU.
        project: GCP project of the TPU.
        extra_args: Optional arguments appended after a ``--`` separator.
        worker_id: Worker index passed as ``--worker`` (default: 0).

    Raises:
        subprocess.CalledProcessError: If gcloud exits with a non-zero status.
    """
    cmd = ["gcloud", "compute", "tpus", "tpu-vm", "ssh", tpu_name]
    cmd += [f"--zone={zone}", f"--project={project}", f"--worker={worker_id}"]
    if extra_args:
        cmd += ["--", *extra_args]

    # Output is not captured, so the session uses the caller's terminal.
    subprocess.run(cmd, check=True)
|
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
# Copyright The Marin Authors
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
from pydantic import BaseModel, ConfigDict, Field
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
# One chat message with the field names of the OpenAI chat format.
# Fields not declared here are accepted and kept (extra="allow").
class OpenAIChatMessage(BaseModel):
    model_config = ConfigDict(extra="allow")

    # Required fields.
    role: str
    content: Any

    # Optional fields; each defaults to None.
    name: str | None = None
    tool_calls: list[dict[str, Any]] | None = None
    tool_call_id: str | None = Field(None, alias="tool_call_id")
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
# One conversation record: identifying fields, the list of chat messages,
# two string timestamps and a free-form metadata mapping.
# Fields not declared here are accepted and kept (extra="allow").
class DolmaConversationOutput(BaseModel):
    model_config = ConfigDict(extra="allow")

    # Every field is required.
    id: str
    source: str
    messages: list[OpenAIChatMessage]
    added: str
    created: str
    metadata: dict[str, Any]
|
marin/core/data.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
1
|
+
# Copyright The Marin Authors
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
from dataclasses import dataclass
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
@dataclass
|
|
8
|
+
class QAExampleMetadata:
|
|
9
|
+
"""
|
|
10
|
+
Dataclass representing metadata for a document.
|
|
11
|
+
|
|
12
|
+
Attributes:
|
|
13
|
+
subset (str): subset of dataset
|
|
14
|
+
split (str): split of dataset
|
|
15
|
+
revision (str): revision of dataset
|
|
16
|
+
provenance (str): URL of source of data, usually HF
|
|
17
|
+
answer (str): text of answer
|
|
18
|
+
answer_idx (str): index into list of answer options corresponding to correct answer
|
|
19
|
+
answer_label (str): label of correct answer (e.g. A)
|
|
20
|
+
options (list[str]): list of potential options for multiple choice question
|
|
21
|
+
answer_labels (list[str]): list of labels for options (e.g. A,B,C,D)
|
|
22
|
+
"""
|
|
23
|
+
|
|
24
|
+
subset: str | None = None
|
|
25
|
+
split: str | None = None
|
|
26
|
+
revision: str | None = None
|
|
27
|
+
provenance: str | None = None
|
|
28
|
+
answer: str | None = None
|
|
29
|
+
answer_idx: int | None = None
|
|
30
|
+
answer_label: str | None = None
|
|
31
|
+
options: list[str] | None = None
|
|
32
|
+
answer_labels: list[str] | None = None
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
@dataclass
class QAExample:
    """
    A single question-answering record together with its metadata.

    ``id``, ``source`` and ``metadata`` are required; the three text fields
    are optional and default to ``None``.

    Attributes:
        id (str): Unique identifier for the record.
        source (str): The name of the dataset.
        metadata (QAExampleMetadata): Metadata related to the dataset.
        text (str | None): The text of the document.
        prompt (str | None): For prompt/response documents, the prompt component.
        response (str | None): For prompt/response documents, the expected response component.
    """

    id: str
    source: str
    metadata: QAExampleMetadata
    text: str | None = None
    prompt: str | None = None
    response: str | None = None
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
# Copyright The Marin Authors
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
"""Datakit: composable pipeline stages with a standard Parquet format.
|
|
5
|
+
|
|
6
|
+
The standard format pins three mandatory columns on every normalized record:
|
|
7
|
+
``id`` (deterministic content hash), ``text`` (UTF-8 primary content), and
|
|
8
|
+
``partition_id`` (int, the output shard the row was written to at normalize
|
|
9
|
+
time). The shard count itself lives on the artifact, not the row.
|
|
10
|
+
|
|
11
|
+
Downstream stages preserve ``partition_id`` and use it as the ``group_by`` key
|
|
12
|
+
when a global shuffle (e.g. cross-document dedup) needs to land output back
|
|
13
|
+
co-partitioned with the source.
|
|
14
|
+
"""
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def partition_filename(partition_id: int, num_partitions: int) -> str:
    """Build the standard datakit shard filename for one partition.

    Shards are named ``part-NNNNN-of-MMMMM.parquet``. Both numbers are
    zero-padded to at least five digits; larger values are written in full.

    Args:
        partition_id: Index of the shard.
        num_partitions: Total number of shards.

    Returns:
        The shard filename, without any directory component.
    """
    shard_index = format(partition_id, "05d")
    shard_total = format(num_partitions, "05d")
    return "part-" + shard_index + "-of-" + shard_total + ".parquet"
|
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
# Copyright The Marin Authors
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
"""Datakit canonical pipeline for FineWeb-Edu.
|
|
5
|
+
|
|
6
|
+
HuggingFace: HuggingFaceFW/fineweb-edu
|
|
7
|
+
|
|
8
|
+
FineWeb-Edu is a filtered subset of FineWeb selected for educational content.
|
|
9
|
+
The raw download is Parquet with columns: text, id, url, dump, file_path,
|
|
10
|
+
language, language_score, token_count, score, int_score.
|
|
11
|
+
|
|
12
|
+
Subsets available on HuggingFace:
|
|
13
|
+
- data/ — full dataset
|
|
14
|
+
- sample/10BT — 10B token sample
|
|
15
|
+
- sample/100BT — 100B token sample
|
|
16
|
+
- sample/350BT — 350B token sample
|
|
17
|
+
"""
|
|
18
|
+
|
|
19
|
+
from fray import ResourceConfig
|
|
20
|
+
|
|
21
|
+
from marin.datakit.download.huggingface import download_hf_step
|
|
22
|
+
from marin.execution.step_spec import StepSpec
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def download(
    *,
    revision: str = "87f0914",
    hf_urls_glob: list[str] | None = None,
    worker_resources: ResourceConfig | None = None,
) -> StepSpec:
    """Build the step that downloads FineWeb-Edu from HuggingFace.

    Args:
        revision: Dataset revision to fetch; it is also embedded in the
            overridden output path.
        hf_urls_glob: Forwarded unchanged to ``download_hf_step``.
        worker_resources: Forwarded unchanged to ``download_hf_step``.

    Returns:
        The ``StepSpec`` produced by ``download_hf_step``.
    """
    dataset_id = "HuggingFaceFW/fineweb-edu"
    # Fix the output location so it is determined by the revision.
    pinned_output = f"raw/fineweb-edu-{revision}"
    step = download_hf_step(
        "raw/fineweb-edu",
        hf_dataset_id=dataset_id,
        revision=revision,
        hf_urls_glob=hf_urls_glob,
        override_output_path=pinned_output,
        worker_resources=worker_resources,
    )
    return step
|
|
@@ -0,0 +1,156 @@
|
|
|
1
|
+
# Copyright The Marin Authors
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
"""
|
|
5
|
+
Download and process Ar5iv dataset from a zip file.
|
|
6
|
+
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
import json
|
|
10
|
+
import logging
|
|
11
|
+
import zipfile
|
|
12
|
+
from collections import defaultdict
|
|
13
|
+
from dataclasses import dataclass
|
|
14
|
+
|
|
15
|
+
import draccus
|
|
16
|
+
from rigging.filesystem import open_url
|
|
17
|
+
from rigging.log_setup import configure_logging
|
|
18
|
+
from zephyr import Dataset, ZephyrContext
|
|
19
|
+
from zephyr.writers import atomic_rename
|
|
20
|
+
|
|
21
|
+
from marin.execution.step_spec import StepSpec
|
|
22
|
+
|
|
23
|
+
logger = logging.getLogger(__name__)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
@dataclass
|
|
27
|
+
class Ar5ivDownloadConfig:
|
|
28
|
+
input_path: str
|
|
29
|
+
output_path: str
|
|
30
|
+
max_files: int | None = None # Maximum number of shards to process
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def process_shard(shard_task: dict) -> dict:
    """
    Copy one shard's files out of the source zip into a gzipped JSONL file.

    Each file becomes one JSON line with keys 'filename', 'format' (always
    "html") and 'content' (the bytes decoded as UTF-8, with undecodable
    bytes replaced).

    Args:
        shard_task: Dict with keys 'input_path', 'output_path', 'shard_id', 'file_list'

    Returns:
        Dict with keys 'shard_id', 'num_files' and 'output_path', where
        'output_path' is the written ``{output_path}/{shard_id}.jsonl.gz``.
    """
    input_path, output_path = shard_task["input_path"], shard_task["output_path"]
    shard_id, file_list = shard_task["shard_id"], shard_task["file_list"]
    gcs_path = f"{output_path}/{shard_id}.jsonl.gz"

    with open_url(str(input_path), "rb") as archive_stream, zipfile.ZipFile(archive_stream) as archive:
        # Write to the temporary path supplied by atomic_rename.
        with atomic_rename(gcs_path) as temp_path, open_url(temp_path, "wt", compression="gzip") as sink:
            for member in file_list:
                with archive.open(member, "r") as member_handle:
                    raw = member_handle.read()
                line = json.dumps(
                    {
                        "filename": member,
                        "format": "html",
                        "content": raw.decode("utf-8", errors="replace"),
                    }
                )
                sink.write(line + "\n")

    logger.info(f"Shard {shard_id} with {len(file_list)} files uploaded to {gcs_path}")
    return {"shard_id": shard_id, "num_files": len(file_list), "output_path": gcs_path}
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def download(cfg: Ar5ivDownloadConfig) -> None:
    """
    Split the Ar5iv zip at ``cfg.input_path`` into per-shard JSONL files.

    The zip's entries are grouped by their immediate parent directory, one
    task is built per group, and the tasks are run through a zephyr pipeline
    that also writes each task's result under ``{output_path}/.metrics``.

    This function can be called by the executor framework or used standalone.
    """
    logger.info("Starting transfer of Ar5iv dataset...")
    logger.info(f"Source: {cfg.input_path}")

    # Only the entry listing is needed here; contents are read per shard later.
    with open_url(str(cfg.input_path), "rb") as f, zipfile.ZipFile(f) as zf:
        members = zf.infolist()

    # The shard id is the second-last path component, i.e. the directory that
    # directly contains the file ("003/something.html" -> "003").
    files_by_shard = defaultdict(list)
    for member in members:
        if member.is_dir():
            continue
        components = member.filename.strip("/").split("/")
        # Entries at the root of the zip have no parent directory and are skipped.
        if len(components) >= 2:
            files_by_shard[components[-2]].append(member.filename)

    # max_files caps the number of shards, not the number of individual files.
    shard_ids = list(files_by_shard)
    if cfg.max_files is not None:
        shard_ids = shard_ids[: cfg.max_files]

    logger.info(f"Found {len(shard_ids)} shards to process.")

    shard_tasks = [
        {
            "input_path": cfg.input_path,
            "output_path": cfg.output_path,
            "shard_id": shard_id,
            "file_list": files_by_shard[shard_id],
        }
        for shard_id in shard_ids
    ]

    metrics_pattern = f"{cfg.output_path}/.metrics/part-{{shard:05d}}.jsonl"
    pipeline = Dataset.from_list(shard_tasks).map(process_shard).write_jsonl(metrics_pattern, skip_existing=True)
    ZephyrContext(name="download-ar5iv").execute(pipeline)

    logger.info("Transfer completed successfully!")
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
def ar5iv_step(
    name: str = "raw/ar5iv",
    *,
    input_path: str,
    max_files: int | None = None,
    deps: list[StepSpec] | None = None,
    output_path_prefix: str | None = None,
    override_output_path: str | None = None,
) -> StepSpec:
    """Wrap the Ar5iv download in a ``StepSpec``.

    Args:
        name: Step name.
        input_path: Location of the source zip file.
        max_files: Optional cap on the number of shards to process.
        deps: Steps this one depends on; ``None`` is passed on as an empty list.
        output_path_prefix: Forwarded unchanged to ``StepSpec``.
        override_output_path: Forwarded unchanged to ``StepSpec``.

    Returns:
        A ``StepSpec`` whose function runs ``download`` with the given
        settings. ``input_path`` and ``max_files`` are its hash attributes.
    """

    def _run(output_path: str) -> None:
        cfg = Ar5ivDownloadConfig(input_path=input_path, output_path=output_path, max_files=max_files)
        download(cfg)

    hash_attrs = {"input_path": input_path, "max_files": max_files}
    return StepSpec(
        name=name,
        fn=_run,
        deps=deps if deps else [],
        hash_attrs=hash_attrs,
        output_path_prefix=output_path_prefix,
        override_output_path=override_output_path,
    )
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
@draccus.wrap()
def main(cfg: Ar5ivDownloadConfig) -> None:
    """Command-line entry point: set up INFO-level logging, then run the download."""
    configure_logging(level=logging.INFO)
    download(cfg)
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
# Copyright The Marin Authors
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
"""Streaming downloaders for biology and chemistry notation slices.
|
|
5
|
+
|
|
6
|
+
Each submodule defines a StepSpec factory for one source family
|
|
7
|
+
(RefSeq, RNAcentral, UniProt, PubChem, RCSB PDB, ChEMBL, MoleculeNet) that
|
|
8
|
+
streams from the upstream mirror, splits the stream into format-preserving
|
|
9
|
+
records via :mod:`marin.transform.bio_chem`, packs short records into longer
|
|
10
|
+
documents for in-context-learning evaluation, and writes the result to
|
|
11
|
+
plain-text-in-parquet that Levanter can read directly.
|
|
12
|
+
|
|
13
|
+
The shared streaming primitives live in :mod:`._runtime`.
|
|
14
|
+
"""
|
|
15
|
+
|
|
16
|
+
from marin.datakit.download.bio_chem._runtime import (
|
|
17
|
+
NotationFormat,
|
|
18
|
+
NotationSliceSpec,
|
|
19
|
+
PackingConfig,
|
|
20
|
+
bio_chem_slice_step,
|
|
21
|
+
run_notation_slice,
|
|
22
|
+
)
|