erictransformer-0.0.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- erictransformer/__init__.py +44 -0
- erictransformer/args/__init__.py +7 -0
- erictransformer/args/eric_args.py +50 -0
- erictransformer/eric_tasks/__init__.py +47 -0
- erictransformer/eric_tasks/args/__init__.py +16 -0
- erictransformer/eric_tasks/args/eric_chat_args.py +21 -0
- erictransformer/eric_tasks/args/eric_generation_args.py +20 -0
- erictransformer/eric_tasks/args/eric_text_classification_args.py +13 -0
- erictransformer/eric_tasks/args/eric_text_to_text_args.py +18 -0
- erictransformer/eric_tasks/chat_stream_handlers/__init__.py +6 -0
- erictransformer/eric_tasks/chat_stream_handlers/args.py +13 -0
- erictransformer/eric_tasks/chat_stream_handlers/default.py +19 -0
- erictransformer/eric_tasks/chat_stream_handlers/gpt_oss.py +147 -0
- erictransformer/eric_tasks/chat_stream_handlers/smol.py +81 -0
- erictransformer/eric_tasks/chat_stream_handlers/stream_handler.py +17 -0
- erictransformer/eric_tasks/chat_templates/__init__.py +1 -0
- erictransformer/eric_tasks/chat_templates/convert.py +67 -0
- erictransformer/eric_tasks/eric_chat.py +369 -0
- erictransformer/eric_tasks/eric_chat_mlx.py +278 -0
- erictransformer/eric_tasks/eric_generation.py +243 -0
- erictransformer/eric_tasks/eric_text_classification.py +231 -0
- erictransformer/eric_tasks/eric_text_to_text.py +283 -0
- erictransformer/eric_tasks/inference_engine/__init__.py +3 -0
- erictransformer/eric_tasks/inference_engine/text_classification.py +28 -0
- erictransformer/eric_tasks/misc/__init__.py +11 -0
- erictransformer/eric_tasks/misc/call_utils.py +69 -0
- erictransformer/eric_tasks/misc/get_pad_eos.py +24 -0
- erictransformer/eric_tasks/misc/rag.py +17 -0
- erictransformer/eric_tasks/results/__init__.py +6 -0
- erictransformer/eric_tasks/results/call_results.py +30 -0
- erictransformer/eric_tasks/tok/__init__.py +0 -0
- erictransformer/eric_tasks/tok/tok_functions.py +118 -0
- erictransformer/eric_tracker/__init__.py +1 -0
- erictransformer/eric_tracker/eric_tracker.py +256 -0
- erictransformer/eric_tracker/save_plot.py +422 -0
- erictransformer/eric_transformer.py +534 -0
- erictransformer/eval_models/__init__.py +1 -0
- erictransformer/eval_models/eval_model.py +75 -0
- erictransformer/exceptions/__init__.py +19 -0
- erictransformer/exceptions/eric_exceptions.py +74 -0
- erictransformer/loops/__init__.py +2 -0
- erictransformer/loops/eval_loop.py +111 -0
- erictransformer/loops/train_loop.py +310 -0
- erictransformer/utils/__init__.py +21 -0
- erictransformer/utils/init/__init__.py +5 -0
- erictransformer/utils/init/get_components.py +204 -0
- erictransformer/utils/init/get_device.py +22 -0
- erictransformer/utils/init/get_logger.py +15 -0
- erictransformer/utils/load_from_repo_or_path.py +14 -0
- erictransformer/utils/test/__init__.py +1 -0
- erictransformer/utils/test/debug_hook.py +20 -0
- erictransformer/utils/timer/__init__.py +1 -0
- erictransformer/utils/timer/eric_timer.py +145 -0
- erictransformer/utils/tok_data/__init__.py +8 -0
- erictransformer/utils/tok_data/num_proc.py +15 -0
- erictransformer/utils/tok_data/save_tok_data.py +36 -0
- erictransformer/utils/tok_data/tok_data_to_dataset.py +48 -0
- erictransformer/utils/tok_data/tok_helpers.py +79 -0
- erictransformer/utils/train/__init__.py +6 -0
- erictransformer/utils/train/confirm_optimizer.py +18 -0
- erictransformer/utils/train/create_dir.py +72 -0
- erictransformer/utils/train/get_num_training_steps.py +15 -0
- erictransformer/utils/train/get_precision.py +22 -0
- erictransformer/utils/train/get_tok_data.py +105 -0
- erictransformer/utils/train/resume.py +62 -0
- erictransformer/validator/__init__.py +11 -0
- erictransformer/validator/eric/__init__.py +2 -0
- erictransformer/validator/eric/eval_validator.py +75 -0
- erictransformer/validator/eric/train_validator.py +143 -0
- erictransformer/validator/eric_validator.py +10 -0
- erictransformer/validator/tasks/__init__.py +5 -0
- erictransformer/validator/tasks/chat_validator.py +28 -0
- erictransformer/validator/tasks/gen_validator.py +28 -0
- erictransformer/validator/tasks/task_validator.py +54 -0
- erictransformer/validator/tasks/tc_validator.py +45 -0
- erictransformer/validator/tasks/tt_validator.py +28 -0
- erictransformer/validator/tok/__init__.py +1 -0
- erictransformer/validator/tok/tok_validator.py +23 -0
- erictransformer-0.0.1.dist-info/METADATA +72 -0
- erictransformer-0.0.1.dist-info/RECORD +83 -0
- erictransformer-0.0.1.dist-info/WHEEL +5 -0
- erictransformer-0.0.1.dist-info/licenses/LICENSE +202 -0
- erictransformer-0.0.1.dist-info/top_level.txt +1 -0
erictransformer/utils/load_from_repo_or_path.py
@@ -0,0 +1,14 @@
+import pathlib
+
+import huggingface_hub
+
+
+def load_from_repo_or_path(path: str, path_in_repo: str) -> str:
+    # Resolve a path that is either local or a repo into a local path.
+    if pathlib.Path(path).exists():
+        return path
+    else:
+        local_repo_dir = huggingface_hub.snapshot_download(
+            repo_id=path, repo_type="dataset", token=True
+        )
+        return str(pathlib.Path(local_repo_dir) / path_in_repo)
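A minimal usage sketch (assuming the wheel is installed; the repo id below is hypothetical, and the remote branch needs a configured Hugging Face token):

    from erictransformer.utils.load_from_repo_or_path import load_from_repo_or_path

    local = load_from_repo_or_path("./data/train.jsonl", "train.jsonl")      # existing path: returned as-is
    remote = load_from_repo_or_path("some-user/some-dataset", "train.jsonl")  # treated as a dataset repo id and downloaded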
erictransformer/utils/test/__init__.py
@@ -0,0 +1 @@
+from .debug_hook import DebugHook
erictransformer/utils/test/debug_hook.py
@@ -0,0 +1,20 @@
+import contextlib
+import typing
+
+class DebugHook:
+    hook: typing.Callable | None
+
+    def __init__(self) -> None:
+        self.hook = None
+
+    def __call__(self, *a, **k) -> typing.Any:
+        if self.hook:
+            return self.hook(*a, **k)
+
+    @contextlib.contextmanager
+    def set(self, hook):
+        self.hook = hook
+        try:
+            yield
+        finally:
+            self.hook = None
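A sketch of the intended call pattern, inferred from the class above: the hook is a no-op until set() installs a callable, and it is cleared again when the context exits.

    from erictransformer.utils.test import DebugHook

    hook = DebugHook()
    assert hook(21) is None        # no hook installed: __call__ returns None

    with hook.set(lambda x: x * 2):
        assert hook(21) == 42      # hook active inside the context

    assert hook(21) is None        # cleared on exit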
erictransformer/utils/timer/__init__.py
@@ -0,0 +1 @@
+from erictransformer.utils.timer.eric_timer import EricTimer
erictransformer/utils/timer/eric_timer.py
@@ -0,0 +1,145 @@
+import csv
+import os
+import time
+from contextlib import contextmanager
+
+
+class EricTimer:
+    def __init__(self, out_dir: str, enabled: bool = True):
+        self.enabled = enabled  # if False, self.report() does nothing.
+        self._stats = {}
+        self._active = {}
+        self._out_dir = out_dir
+        self._category_dir = os.path.join(out_dir, "categories")
+        self._ensure_dirs()
+        self.start_time = time.perf_counter()
+
+    @contextmanager
+    def section(self, category: str, label: str):
+        self.start(category, label)
+        try:
+            yield
+        finally:
+            self.stop(category, label)
+
+    def start(self, category: str, label: str):
+        if self._active:
+            raise ValueError(f"EricTimer is already running with {self._active}")
+        if label == "misc_time":
+            raise ValueError("misc_time is a reserved label")
+
+        if category not in self._active:
+            self._active[category] = {}
+
+        if label in self._active[category]:
+            raise ValueError(f"{category!r}/{label!r} is active. Call stop() first.")
+
+        self._active[category][label] = time.perf_counter()
+
+    def stop(self, category: str, label: str):
+        if category not in self._active or label not in self._active[category]:
+            raise RuntimeError(
+                f"stop() called without calling start() for {category!r}/{label!r} first."
+            )
+
+        t0 = self._active[category][label]
+        del self._active[category][label]
+        if not self._active[category]:
+            del self._active[category]
+
+        dt = time.perf_counter() - t0
+
+        if category not in self._stats:
+            self._stats[category] = {}
+
+        if label not in self._stats[category]:
+            self._stats[category][label] = {"total": 0.0, "count": 0}
+
+        self._stats[category][label]["total"] += dt
+        self._stats[category][label]["count"] += 1
+
+        return dt
+
+    def _ensure_dirs(self):
+        if self.enabled:
+            if self._out_dir:
+                os.makedirs(self._out_dir, exist_ok=True)
+            if self._category_dir:
+                os.makedirs(self._category_dir, exist_ok=True)
+
+    def _category_path(self, category: str):
+        fname = f"{category}.csv"
+        return os.path.join(self._category_dir, fname)
+
+    def _write_category_csv(self, category: str):
+        if category in self._stats:
+            labels = self._stats[category]
+        else:
+            labels = {}
+
+        rows = []
+        for label, s in labels.items():
+            total = float(s["total"])
+            count = int(s["count"])
+            mean = (total / count) if count else 0.0
+            rows.append(
+                {
+                    "label": label,
+                    "total": round(total, 4),
+                    "mean": round(mean, 4),
+                    "count": count,
+                }
+            )
+        rows.sort(key=lambda r: r["total"], reverse=True)
+
+        with open(
+            self._category_path(category), "w", newline="", encoding="utf-8"
+        ) as f:
+            writer = csv.DictWriter(f, fieldnames=["label", "total", "mean", "count"])
+            writer.writeheader()
+            for r in rows:
+                writer.writerow(r)
+
+    def _write_categories_summary_csv(self):
+
+        if self.enabled:
+            path = os.path.join(self._out_dir, "summary.csv")
+            rows = []
+            for category, labels in self._stats.items():
+                total = 0.0
+                for s in labels.values():
+                    total += float(s["total"])
+                rows.append({"category": category, "total": round(total, 4)})
+            rows.sort(key=lambda r: r["total"], reverse=True)
+
+            with open(path, "w", newline="", encoding="utf-8") as f:
+                writer = csv.DictWriter(f, fieldnames=["category", "total"])
+                writer.writeheader()
+                for r in rows:
+                    writer.writerow(r)
+
+    def report(self):
+        if self.enabled:
+            accounted_time = 0.0
+            for labels in self._stats.values():
+                for s in labels.values():
+                    accounted_time += float(s["total"])
+
+            misc_time = time.perf_counter() - self.start_time - accounted_time
+            if misc_time < 0:
+                misc_time = 0.0
+
+            misc_cat = "misc"
+            if misc_cat not in self._stats:
+                self._stats[misc_cat] = {}
+            self._stats[misc_cat]["misc_time"] = {"total": misc_time, "count": 1}
+
+            for category in list(self._stats.keys()):
+                self._write_category_csv(category)
+
+            self._write_categories_summary_csv()
+
+    def reset(self):
+        self._stats.clear()
+        self._active.clear()
+        self.start_time = time.perf_counter()
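A minimal usage sketch (out_dir is a throwaway temp directory here). Note that sections cannot be nested: start() raises while any other label is active.

    import tempfile
    import time

    from erictransformer.utils.timer import EricTimer

    timer = EricTimer(out_dir=tempfile.mkdtemp())
    with timer.section("train", "forward"):
        time.sleep(0.01)
    timer.report()  # writes categories/train.csv and summary.csv, plus a "misc" bucket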
erictransformer/utils/tok_data/__init__.py
@@ -0,0 +1,8 @@
+from erictransformer.utils.tok_data.num_proc import get_procs
+from erictransformer.utils.tok_data.save_tok_data import save_json_tok_data
+from erictransformer.utils.tok_data.tok_data_to_dataset import tok_dir_to_dataset
+from erictransformer.utils.tok_data.tok_helpers import (
+    prepare_output_locations,
+    resolve_input_files,
+    write_details_file,
+)
erictransformer/utils/tok_data/num_proc.py
@@ -0,0 +1,15 @@
+import math
+import os
+from typing import Optional
+
+
+def get_procs(procs: int) -> Optional[int]:
+    if procs == -1:
+        return None
+    elif procs == 0:
+        cpu_count = os.cpu_count()
+        if cpu_count is None or cpu_count <= 1:
+            return 1
+        return math.floor(cpu_count / 2)
+
+    return procs
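The sentinel values are easy to misread, so a quick sketch of the mapping:

    from erictransformer.utils.tok_data import get_procs

    get_procs(-1)  # None: defer to the caller/library default
    get_procs(0)   # half the available CPUs (at least 1)
    get_procs(4)   # any positive value is passed through unchanged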
erictransformer/utils/tok_data/save_tok_data.py
@@ -0,0 +1,36 @@
+import json
+
+from erictransformer.exceptions import EricDatasetError, EricIOError
+
+
+def save_json_tok_data(tok_dataset, output_files, shards, formatter):
+    total = len(tok_dataset)
+    per_shard = total // shards
+    remainder = total % shards
+
+    index = 0
+
+    for shard_id, output_file in enumerate(output_files):
+        count = per_shard + (1 if shard_id < remainder else 0)
+
+        for _ in range(count):
+            if index >= total:
+                break
+            try:
+                raw = formatter(tok_dataset[index])
+            except Exception as e:
+                raise EricDatasetError(
+                    f"Failed to format dataset example at index {index}: {e}"
+                )
+
+            try:
+                example = {
+                    k: v.tolist() if hasattr(v, "tolist") else v for k, v in raw.items()
+                }
+                output_file.write(json.dumps(example) + "\n")
+            except Exception as e:
+                raise EricIOError(
+                    f"Failed to write formatted example at index {index} to file: {e}"
+                )
+
+            index += 1
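The shard sizes come from floor division plus a spread remainder: with 10 examples over 3 shards, per_shard is 3 and remainder is 1, so the shards receive 4, 3, and 3 examples. A minimal sketch with an identity formatter and throwaway paths:

    from erictransformer.utils.tok_data import save_json_tok_data

    rows = [{"input_ids": [1, 2, 3]}] * 10
    files = [open(f"/tmp/tok_shard_{i + 1}.jsonl", "w") for i in range(3)]
    save_json_tok_data(rows, files, shards=3, formatter=lambda ex: ex)
    for f in files:
        f.close()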
erictransformer/utils/tok_data/tok_data_to_dataset.py
@@ -0,0 +1,48 @@
+import json
+import os
+
+from datasets import Dataset
+
+from erictransformer.exceptions import EricDatasetError, EricInputError, EricIOError
+
+
+def tok_dir_to_dataset(tok_dir: str):
+    # Validate the input directory
+    if not os.path.isdir(tok_dir):
+        raise EricInputError(f"Invalid tok_dir: {tok_dir} is not a directory.")
+
+    # Attempt to construct the path to eric_details.json
+    try:
+        eric_details_path = os.path.join(tok_dir, "eric_details.json")
+    except Exception as e:
+        raise EricInputError(f"Error creating path to eric_details.json: {e}")
+
+    # Read the eric_details.json file
+    try:
+        with open(eric_details_path, "r") as f:
+            eric_details = json.load(f)
+    except Exception as e:
+        raise EricIOError(f"Unable to read '{eric_details_path}': {e}")
+
+    # List of full paths to all JSONL files
+    full_paths = eric_details.get("paths", [])
+    if not full_paths:
+        raise EricDatasetError("No file paths found in eric_details.json.")
+
+    # Load all entries from each JSONL file
+    all_data = []
+    for path in full_paths:
+        try:
+            with open(path, "r") as f:
+                for line in f:
+                    all_data.append(json.loads(line))
+        except Exception as e:
+            raise EricIOError(f"Unable to read '{path}': {e}")
+
+    # Convert list of dicts to Hugging Face Dataset
+    try:
+        dataset = Dataset.from_list(all_data)
+    except Exception as e:
+        raise EricDatasetError(f"Unable to create dataset: {e}")
+
+    return dataset
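A usage sketch, assuming ./tok_out is a hypothetical directory produced by the tokenization step (it must contain eric_details.json plus the .jsonl shards that file lists):

    from erictransformer.utils.tok_data import tok_dir_to_dataset

    dataset = tok_dir_to_dataset("./tok_out")
    print(len(dataset), dataset[0])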
erictransformer/utils/tok_data/tok_helpers.py
@@ -0,0 +1,79 @@
+import glob
+import json
+import os
+from datetime import datetime
+from typing import List, Tuple, Union
+
+from datasets import Dataset
+
+from erictransformer.exceptions import EricInputError, EricIOError
+
+EXT_TO_TYPE_MAP = {
+    ".json": "json",
+    ".jsonl": "json",  # jsonl is loaded as json
+}
+
+
+def resolve_input_files(input_data: Union[str, Dataset]) -> List[tuple]:
+    try:
+        if isinstance(input_data, str):
+            if os.path.isdir(input_data):
+                return collect_files_from_dir(input_data)
+            elif os.path.isfile(input_data):
+                ext = os.path.splitext(input_data)[1].lower()
+                if ext not in EXT_TO_TYPE_MAP:
+                    raise EricInputError(f"Unsupported file extension: {ext}")
+                return [(input_data, EXT_TO_TYPE_MAP[ext])]
+            else:
+                raise EricInputError(f"Path does not exist: {input_data}")
+        else:
+            raise EricInputError("Input must be a file path or directory string.")
+    except Exception as e:
+        raise EricInputError(f"Failed to resolve input files: {e}")
+
+
+def collect_files_from_dir(dir_path: str) -> List[Tuple[str, str]]:
+    try:
+        jsonl = [(f, "json") for f in glob.glob(os.path.join(dir_path, "*.jsonl"))]
+        # text = [(f, "text") for f in glob.glob(os.path.join(dir_path, "*.txt"))]
+        # we might support .txt files for EricGeneration models one day
+        all_files = jsonl
+        if not all_files:
+            raise EricInputError(
+                f"No valid jsonl files found in directory: {dir_path}"
+            )
+        return all_files
+    except Exception as e:
+        raise EricIOError(f"Failed to collect files from directory '{dir_path}': {e}")
+
+
+def prepare_output_locations(out_dir: str, shards: int):
+    try:
+        os.makedirs(out_dir, exist_ok=True)
+        paths = [
+            os.path.join(out_dir, f"tok_shard_{i + 1}.jsonl")
+            for i in range(shards)
+        ]
+        handles = []
+        for p in paths:
+            try:
+                handles.append(open(p, "w"))
+            except OSError as e:
+                raise EricIOError(f"Failed to open file for writing: {p} - {e}")
+        return paths, handles
+    except Exception as e:
+        raise EricIOError(f"Failed to prepare output locations in '{out_dir}': {e}")
+
+
+def write_details_file(out_dir: str, num_cases: int, output_paths: List[str]):
+    try:
+        detail = {
+            "num_cases": num_cases,
+            "timestamp": datetime.now().isoformat(),
+            "paths": output_paths,
+        }
+        details_path = os.path.join(out_dir, "eric_details.json")
+        with open(details_path, "w") as f:
+            json.dump(detail, f, indent=2)
+    except Exception as e:
+        raise EricIOError(f"Failed to write details file in '{out_dir}': {e}")
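These helpers are two halves of one round trip: prepare_output_locations opens the shard files that save_json_tok_data fills, and write_details_file records the manifest that tok_dir_to_dataset and get_tok_data later read. A sketch with a hypothetical out_dir:

    from erictransformer.utils.tok_data import (
        prepare_output_locations,
        write_details_file,
    )

    paths, handles = prepare_output_locations("./tok_out", shards=2)
    # ... write one JSON object per line to each handle, then:
    for h in handles:
        h.close()
    write_details_file("./tok_out", num_cases=100, output_paths=paths)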
erictransformer/utils/train/__init__.py
@@ -0,0 +1,6 @@
+from erictransformer.utils.train.confirm_optimizer import get_optim
+from erictransformer.utils.train.create_dir import create_tracker_dir, make_dir
+from erictransformer.utils.train.get_num_training_steps import get_num_training_steps
+from erictransformer.utils.train.get_precision import get_precision
+from erictransformer.utils.train.get_tok_data import get_tok_data
+from erictransformer.utils.train.resume import resume_training
erictransformer/utils/train/confirm_optimizer.py
@@ -0,0 +1,18 @@
+from torch.optim import SGD, AdamW
+
+from erictransformer.args import EricTrainArgs
+
+
+def get_optim(args: EricTrainArgs, model, logger):
+    if args.optim == "adamw":
+        logger.debug("Using PyTorch's AdamW optimizer.")
+        return AdamW(model.parameters(), lr=args.lr, weight_decay=0.05)
+
+    elif args.optim == "sgd":
+        logger.debug("Using PyTorch's SGD optimizer.")
+        return SGD(
+            model.parameters(), lr=args.lr, momentum=0, weight_decay=1e-4
+        )
+
+    else:
+        raise ValueError(f"Unsupported optimizer: {args.optim}")
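get_optim only reads args.optim and args.lr, so a stand-in object is enough to sketch the call (SimpleNamespace here substitutes for EricTrainArgs, whose constructor is not shown in this diff; the Linear model is a placeholder for any nn.Module):

    import logging
    from types import SimpleNamespace

    import torch.nn as nn

    from erictransformer.utils.train import get_optim

    args = SimpleNamespace(optim="adamw", lr=5e-5)
    model = nn.Linear(4, 2)
    optimizer = get_optim(args, model, logging.getLogger(__name__))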
erictransformer/utils/train/create_dir.py
@@ -0,0 +1,72 @@
+import datetime as dt
+import logging
+import os
+import re
+from pathlib import Path
+
+from erictransformer.exceptions import EricIOError
+
+
+def _slugify(name: str) -> str:
+    safe_chars = re.compile(r"[^A-Za-z0-9._-]+")
+    s = name.strip()
+    s = s.replace(" ", "-")
+    s = safe_chars.sub("-", s)
+    s = re.sub(r"-{2,}", "-", s)
+    s = s.strip("-._")
+    if not s:
+        raise EricIOError(f"Invalid directory name: {name}")
+    return s
+
+
+def _increment_path(base_path: Path) -> Path:
+    parent = base_path.parent
+    name = base_path.name
+
+    m = re.match(r"^(.*?)(?:-(\d+))?$", name)
+    root = m.group(1)
+    pat = re.compile(rf"^{re.escape(root)}(?:-(\d+))?$")
+
+    max_i = -1
+    if parent.exists():
+        for child in parent.iterdir():
+            m2 = pat.match(child.name)
+            if m2:
+                i = int(m2.group(1)) if m2.group(1) else 0
+                if i > max_i:
+                    max_i = i
+
+    if max_i < 0:
+        return base_path
+    next_i = max_i + 1
+    # next_i is >= 1 here since max_i >= 0. Build the suffix on the
+    # root name so "run-2" increments to "run-3" rather than
+    # "run-2-3".
+    return parent / f"{root}-{next_i}"
+
+
+def make_dir(name):
+    try:
+        os.makedirs(name)
+        logging.info(f"Directory {name} created successfully.")
+    except FileExistsError:
+        logging.warning(f"Directory {name} already exists.")
+
+
+def create_tracker_dir(out_dir: str, label: str, run_name: str):
+    timestamp = dt.datetime.now().strftime("%Y%m%d_%H%M%S")
+    if run_name:
+        slugify_run_name = _slugify(run_name)
+
+        output = Path(out_dir)
+        output.mkdir(parents=True, exist_ok=True)
+
+        dir_name = _increment_path(output / slugify_run_name)
+
+        make_dir(dir_name)
+
+    else:
+        dir_name = os.path.join(out_dir, f"{label}_{timestamp}")
+        make_dir(dir_name)
+
+    return dir_name
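A sketch of the naming behavior: with a run_name the directory is slugified and auto-incremented; without one, a label plus timestamp fallback is used.

    from erictransformer.utils.train import create_tracker_dir

    d1 = create_tracker_dir("out", label="train", run_name="my run!")  # out/my-run
    d2 = create_tracker_dir("out", label="train", run_name="my run!")  # out/my-run-1
    d3 = create_tracker_dir("out", label="train", run_name="")         # out/train_YYYYMMDD_HHMMSS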
erictransformer/utils/train/get_precision.py
@@ -0,0 +1,22 @@
+import torch
+
+
+def get_precision(device) -> str:
+    is_cuda = device.type == "cuda"
+    bf16_ok = (
+        bool(getattr(torch.cuda, "is_bf16_supported", lambda: False)())
+        if is_cuda
+        else False
+    )
+
+    use_bf16 = is_cuda and bf16_ok
+    use_fp16 = is_cuda and not use_bf16
+
+    if use_bf16:
+        precision_type = "bf16"
+    elif use_fp16:
+        precision_type = "fp16"
+    else:
+        precision_type = "fp32"
+
+    return precision_type
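A usage sketch: the result depends only on the device and bf16 support, so bf16-capable CUDA GPUs return "bf16", older CUDA GPUs "fp16", and everything else (CPU, MPS) "fp32".

    import torch

    from erictransformer.utils.train import get_precision

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(get_precision(device))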
erictransformer/utils/train/get_tok_data.py
@@ -0,0 +1,105 @@
+import json
+from pathlib import Path
+from typing import List
+
+from datasets import load_dataset
+from torch import device
+from torch.utils.data import DataLoader
+from transformers import DataCollator
+
+from erictransformer.exceptions import EricDatasetError, EricInputError
+
+
+def get_tok_data(
+    tok_dir: str,
+    seed: int,
+    bs: int,
+    collate_fn: DataCollator,
+    device_type: device,
+):
+    details = _load_details_file(tok_dir)
+    paths = _resolve_dataset_paths(details, tok_dir)
+    dataset = _load_streaming_dataset(paths, seed)
+    dataloader = _build_dataloader(dataset, bs, collate_fn, device_type)
+
+    return dataloader, details["num_cases"]
+
+
+def _load_details_file(tok_dir: str) -> dict:
+    details_filepath = Path(tok_dir) / "eric_details.json"
+    if not details_filepath.is_file():
+        raise FileNotFoundError(f"Missing required file: {details_filepath}")
+    try:
+        with open(details_filepath, "r", encoding="utf-8") as f:
+            return json.load(f)
+    except Exception as e:
+        raise EricInputError(f"Failed to parse eric_details.json: {e}")
+
+
+def _resolve_dataset_paths(details: dict, tok_dir: str) -> List[Path]:
+    try:
+        paths = [Path(p) for p in details["paths"]]
+    except KeyError:
+        raise EricInputError(
+            f"'paths' field missing in eric_details.json in {tok_dir}"
+        )
+
+    if not all(p.is_file() and p.suffix == ".jsonl" for p in paths):
+        raise EricInputError("All dataset paths must exist and be `.jsonl` files")
+    return paths
+
+
+def _load_streaming_dataset(paths: List[Path], seed: int):
+    data_files = {"train": [str(p) for p in paths]}
+    try:
+        ds = load_dataset(
+            "json",
+            data_files=data_files,
+            split="train",
+            streaming=True,
+        )
+
+        # Probe for at least one example (to handle the 'all shards empty' case)
+        it = iter(ds)
+        try:
+            _ = next(it)
+        except StopIteration:
+            raise EricInputError("No non-empty datasets to stream")
+
+        # Re-create the iterator since we advanced one element when probing
+        ds = load_dataset("json", data_files=data_files, split="train", streaming=True)
+
+        ds = ds.shuffle(buffer_size=10_000, seed=seed)
+        return ds
+    except EricInputError:
+        raise
+    except Exception as e:
+        raise EricDatasetError(
+            f"Failed to load streaming dataset from paths {paths}: {e}"
+        )
+
+
+def _build_dataloader(
+    dataset, bs: int, collate_fn: DataCollator, device_type: device
+):
+    try:
+        shards = getattr(dataset, "n_shards", 1)
+        base_workers = min(4, shards)
+
+        if device_type.type == "mps":
+            workers = 0
+        else:
+            workers = base_workers
+
+        return DataLoader(
+            dataset,
+            batch_size=bs,
+            collate_fn=collate_fn,
+            num_workers=workers,
+            pin_memory=(device_type.type == "cuda"),
+            persistent_workers=(workers > 0),
+            prefetch_factor=4 if workers > 0 else None,
+        )
+
+    except Exception as e:
+        raise EricInputError(f"Failed to create DataLoader: {e}")
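A sketch of wiring the streaming loader together (the tokenizer checkpoint and ./tok_out directory are placeholders; tok_dir must contain the eric_details.json manifest written by the tok_data helpers above):

    import torch
    from transformers import AutoTokenizer, DataCollatorWithPadding

    from erictransformer.utils.train import get_tok_data

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    loader, num_cases = get_tok_data(
        tok_dir="./tok_out",
        seed=42,
        bs=16,
        collate_fn=DataCollatorWithPadding(tokenizer),
        device_type=torch.device("cpu"),
    )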
erictransformer/utils/train/resume.py
@@ -0,0 +1,62 @@
+import json
+import os
+from pathlib import Path
+
+from erictransformer.exceptions import EricResumeError
+
+
+def resume_training(resume_path: str):
+    tracker_state = _load_tracker_state(resume_path)
+    args_dict = _load_train_args(resume_path)
+    model_tokenizer_path, lr_sched_path = _resolve_training_paths(
+        resume_path
+    )
+
+    return (
+        tracker_state,
+        args_dict,
+        model_tokenizer_path,
+        lr_sched_path,
+    )
+
+
+def _load_tracker_state(resume_path: str):
+    logs_path = Path(os.path.join(resume_path, "eric_tracker_state.json"))
+    try:
+        if not logs_path.exists():
+            raise EricResumeError(f"No log file found at {logs_path}.")
+
+        with logs_path.open(encoding="utf-8") as f:
+            tracker_state = json.load(f)
+
+        if tracker_state is None:
+            raise EricResumeError("Loaded tracker state is None.")
+
+        return tracker_state
+    except Exception as e:
+        raise EricResumeError(f"Failed to load tracker state: {e}")
+
+
+def _load_train_args(resume_path: str):
+    train_args_path = Path(os.path.join(resume_path, "train_args.json"))
+    try:
+        if not train_args_path.exists():
+            raise EricResumeError(f"No train_args.json found at {train_args_path}.")
+
+        args_dict = json.loads(train_args_path.read_text(encoding="utf-8"))
+        return args_dict
+    except Exception as e:
+        raise EricResumeError(f"Failed to load training arguments: {e}")
+
+
+def _resolve_training_paths(resume_path: str):
+    try:
+        lr_sched_path = os.path.join(resume_path, "lr_sched.pt")
+
+        if not Path(lr_sched_path).exists():
+            raise EricResumeError(f"lr_sched.pt not found at {lr_sched_path}")
+
+
+        return resume_path, lr_sched_path
+    except Exception as e:
+        raise EricResumeError(f"Failed to resolve training paths: {e}")
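A usage sketch with a hypothetical checkpoint directory; it must hold eric_tracker_state.json, train_args.json, and lr_sched.pt, and the function hands back the directory itself as the model/tokenizer path:

    from erictransformer.utils.train import resume_training

    tracker_state, args_dict, model_tokenizer_path, lr_sched_path = resume_training(
        "./runs/my-run"
    )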