sarasa 0.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sarasa/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .config import Config as Config
2
+ from .trainer import Trainer as Trainer
@@ -0,0 +1,81 @@
1
+ from collections import defaultdict
2
+
3
+ import torch
4
+ from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import checkpoint_wrapper
5
+ from torch.utils.checkpoint import CheckpointPolicy, create_selective_checkpoint_contexts
6
+
7
# for selective op activation checkpointing: default set of ops whose
# outputs are saved during forward rather than recomputed in backward.
# Matmuls and SDPA variants dominate recompute cost, so saving them is
# usually the right memory/compute trade-off.
_ops_sac_save = {
    torch.ops.aten.mm.default,
    torch.ops.aten._scaled_dot_product_efficient_attention.default,
    torch.ops.aten._scaled_dot_product_flash_attention.default,
    torch.ops.aten._scaled_dot_product_cudnn_attention.default,
    torch.ops.aten._scaled_dot_product_attention_math.default,
    torch.ops.aten._scaled_dot_product_fused_attention_overrideable.default,
    # collective output — recomputing it would re-issue communication
    torch.ops._c10d_functional.reduce_scatter_tensor.default,
    # for low precision training, it's useful to always save
    # the result of max, since the absolute maximum is
    # used to compute the scaling factor for quantization.
    torch.ops.aten.max.default,
    torch._higher_order_ops.inductor_compiled_code,
}
22
+
23
+
24
def _op_sac_policy(
    ops_to_save: set,
    mm_recompute_shapes: set | None,
    every_nth_mm: int,
):
    """Build a context-fn factory implementing a selective-checkpoint policy.

    Returns a zero-arg callable suitable as ``context_fn`` for activation
    checkpointing: each call creates fresh selective-checkpoint contexts whose
    policy saves ops in ``ops_to_save``, except that only every
    ``every_nth_mm``-th mm is saved and mms whose second operand's shape is in
    ``mm_recompute_shapes`` are always recomputed.
    """
    mm_recompute_shapes = mm_recompute_shapes or set()

    def _get_custom_policy(meta: dict):
        # ``meta`` accumulates per-context mm counters, keyed separately for
        # the forward and recompute passes so both passes count identically.
        def _custom_policy(ctx, func, *args, **kwargs):
            # special case, offload to CPU: a cuda->cpu _to_copy is treated as
            # an offload and its result is always saved
            if (
                func == torch.ops.aten._to_copy.default
                and "cuda" in str(args[0].device)
                and str(kwargs.get("device", "")) == "cpu"
            ):
                return CheckpointPolicy.MUST_SAVE

            # track mm ops (separately per pass; see meta above)
            mode = "recompute" if ctx.is_recompute else "forward"
            key = f"{mode}_mm_count"

            if func == torch.ops.aten.mm.default:
                if len(args) > 1 and args[1].shape in mm_recompute_shapes:
                    # always recompute mms with flagged RHS shapes
                    # (presumably small projections like an MoE router — cheap
                    # to recompute; confirm against callers)
                    return CheckpointPolicy.PREFER_RECOMPUTE
                meta[key] += 1

            # save ops in save list, except every nth mm op
            must_save = (func in ops_to_save) and not (
                func == torch.ops.aten.mm.default and (meta[key] % every_nth_mm == 0)
            )
            return CheckpointPolicy.MUST_SAVE if must_save else CheckpointPolicy.PREFER_RECOMPUTE

        return _custom_policy

    def selective_checkpointing_context_fn():
        # fresh counters per checkpointed region / invocation
        return create_selective_checkpoint_contexts(_get_custom_policy(defaultdict(int)))

    return selective_checkpointing_context_fn
63
+
64
+
65
def apply_op_sac(
    model: torch.nn.Module,
    ops_to_save: set | None = None,
    mm_recompute_shapes: set | None = None,
    every_nth_mm: int = 2,
) -> torch.nn.Module:
    """Applies selective op activation checkpointing to the given model.

    Ops like mm are expensive, so we want to store their activations for
    backward. On the other hand, ops like activation functions are cheap,
    so we prefer to recompute them.

    Args:
        model: module to wrap.
        ops_to_save: ops whose outputs are saved; defaults to
            ``_ops_sac_save`` (an empty set also falls back to the default).
        mm_recompute_shapes: RHS shapes of mm ops to always recompute.
        every_nth_mm: save only every n-th mm; the rest are recomputed.

    Returns:
        The checkpoint-wrapped module.
    """
    ops_to_save = ops_to_save or _ops_sac_save
    # The policy factory must be passed as ``context_fn`` (forwarded to
    # torch.utils.checkpoint.checkpoint). Passing it positionally binds it
    # to ``checkpoint_wrapper``'s second parameter (``checkpoint_impl``),
    # which silently disables the selective policy.
    return checkpoint_wrapper(
        model,
        context_fn=_op_sac_policy(ops_to_save, mm_recompute_shapes, every_nth_mm),
    )
sarasa/checkpoint.py ADDED
@@ -0,0 +1,112 @@
1
+ import enum
2
+ import gc
3
+ import time
4
+ from pathlib import Path
5
+
6
+ import torch
7
+ import torch.distributed as dist
8
+ import torch.distributed.checkpoint as dcp
9
+ from loguru import logger
10
+ from torch.distributed.checkpoint.staging import DefaultStager, StagingOptions
11
+ from torch.distributed.checkpoint.state_dict import get_model_state_dict
12
+ from torch.distributed.checkpoint.state_dict_saver import AsyncCheckpointerType
13
+ from torch.distributed.checkpoint.stateful import Stateful
14
+
15
+ from sarasa.config import Config
16
+
17
+
18
+ class AsyncMode(enum.StrEnum):
19
+ none = enum.auto()
20
+ default = enum.auto()
21
+ mem_pinned = enum.auto()
22
+
23
+
24
+ class ModelWrapper(Stateful):
25
+ def __init__(self, model: torch.nn.Module):
26
+ self.model = model
27
+
28
+ def state_dict(self) -> dict[str, torch.Tensor]:
29
+ return {"model": get_model_state_dict(self.model)}
30
+
31
+ def load_state_dict(self, state_dict: dict[str, torch.Tensor]) -> None:
32
+ raise NotImplementedError("...")
33
+
34
+
35
class Checkpointer:
    """Periodically saves model checkpoints via torch.distributed.checkpoint.

    Supports synchronous saves as well as two async modes (see ``AsyncMode``):
    plain ``dcp.async_save`` and process-based saving through a pinned-memory
    stager. At most one async save is in flight at a time.
    """

    def __init__(
        self,
        config: Config,
        model: torch.nn.Module,
    ):
        self.config = config
        self.checkpoint_freq = config.checkpoint.save_freq
        self.checkpoint_dir = Path(config.output_dir) / "checkpoints"
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
        self.async_mode = AsyncMode(config.checkpoint.async_mode)
        if self.async_mode != AsyncMode.none:
            # gloo group for checkpoint coordination; ``self.pg`` only exists
            # in async modes — the sync path below never reads it.
            self.pg = dist.new_group(backend="gloo") if dist.is_initialized() else None

        self.stager = None  # lazily created for mem_pinned mode
        self.save_future = None  # completion of the in-flight async save
        self.stage_future = None  # completion of staging (mem_pinned only)

        self.state = ModelWrapper(model)

    @torch.no_grad()
    def save(
        self,
        step: int,
    ) -> None:
        """Save a checkpoint for ``step`` if it is on the save-frequency grid.

        No-ops when ``step`` is not a multiple of ``checkpoint_freq``. In
        async modes, blocks first on the previous save's completion.
        """
        if step % self.checkpoint_freq != 0:
            return

        begin = time.perf_counter()
        checkpoint_id = str(self.checkpoint_dir / f"checkpoint_{step:09d}")

        # todo: save other states
        state_dict = self.state.state_dict()

        if self.async_mode == AsyncMode.default:
            # gc around async_save to avoid a GC pause mid-save
            gc.collect(1)
            if self.save_future is not None:
                # wait for the previous async save before starting a new one
                self.save_future.result()
            self.save_future = dcp.async_save(
                state_dict,
                storage_writer=None,
                checkpoint_id=checkpoint_id,
                process_group=self.pg,
            )
            gc.collect(1)
        elif self.async_mode == AsyncMode.mem_pinned:
            gc.collect(1)
            if self.save_future is not None:
                self.save_future.result()
            if self.stager is None:
                # NOTE(review): positional StagingOptions flags — presumably
                # enabling pinned-memory staging features; confirm against
                # the StagingOptions signature for this torch version.
                self.stager = DefaultStager(StagingOptions(True, True, True, True))
            ret = dcp.async_save(
                state_dict,
                storage_writer=None,
                checkpoint_id=checkpoint_id,
                process_group=self.pg,
                async_checkpointer_type=AsyncCheckpointerType.PROCESS,
                async_stager=self.stager,
            )
            # upload runs in a background process; staging must finish before
            # training mutates the weights (see wait_for_staging).
            self.save_future = ret.upload_completion
            self.stage_future = ret.staging_completion
        else:
            # synchronous save: blocks until the checkpoint is on disk
            ret = dcp.save(
                state_dict,
                storage_writer=None,
                checkpoint_id=checkpoint_id,
            )

        logger.info(f"Finished saving checkpoint at step {step} in {time.perf_counter() - begin:.2f} seconds")

    def wait_for_staging(self) -> None:
        """Block until the in-flight staging copy has finished."""
        # no-op if not using mem_pinned async mode
        if self.stage_future is not None:
            self.stage_future.result()

    def close(self) -> None:
        """Release stager resources (background process / pinned buffers)."""
        if self.stager is not None:
            self.stager.close()
sarasa/config.py ADDED
@@ -0,0 +1,279 @@
1
+ from __future__ import annotations
2
+
3
+ import dataclasses
4
+ import sys
5
+ from pathlib import Path
6
+ from typing import Literal
7
+
8
+ import torch
9
+
10
+ """
11
+ Variable configuration dataclasses for model, optimizer, lr scheduler, and data
12
+ These classes have `create` methods to instantiate the actual objects
13
+
14
+ Users can define their own configuration dataclasses and pass them to Config.from_cli to use custom components
15
+ """
16
+ from sarasa.data import DataConfig as Data # noqa
17
+ from sarasa.models import ModelConfig as Model # noqa
18
+ from sarasa.optimizers import AdamW # noqa
19
+
20
+
21
+ @dataclasses.dataclass
22
+ class LRScheduler:
23
+ warmup_steps: int = 200
24
+ decay_ratio: float | None = None
25
+ """If set, the ratio of total steps to apply decay after warmup. If None, decay starts immediately after warmup."""
26
+
27
+ decay_type: Literal["linear", "cosine", "sqrt"] = "linear"
28
+ min_lr_factor: float = 0.0
29
+
30
+ def create(
31
+ self,
32
+ optimizer: torch.optim.Optimizer,
33
+ total_iters: int,
34
+ ) -> torch.optim.lr_scheduler._LRScheduler:
35
+ assert self.decay_ratio is None or (0 <= self.decay_ratio <= 1), "decay_ratio must be between 0 and 1"
36
+ warmup_steps = self.warmup_steps
37
+ stay_steps = 0 if self.decay_ratio is None else int(total_iters * (1 - self.decay_ratio)) - warmup_steps
38
+ decay_steps = total_iters - warmup_steps - stay_steps
39
+ assert warmup_steps >= 0 and decay_steps >= 0 and stay_steps >= 0, (
40
+ f"Invalid lr scheduler steps configuration: {warmup_steps=}, {decay_steps=}, {stay_steps=}"
41
+ )
42
+
43
+ # 1 / max(1, warmup_steps) to avoid division by zero
44
+ warmup = torch.optim.lr_scheduler.LinearLR(optimizer, 1 / max(1, warmup_steps), total_iters=warmup_steps)
45
+
46
+ stay = torch.optim.lr_scheduler.ConstantLR(optimizer=optimizer, factor=1.0, total_iters=stay_steps)
47
+
48
+ match self.decay_type:
49
+ case "linear":
50
+ decay = torch.optim.lr_scheduler.LinearLR(
51
+ optimizer,
52
+ start_factor=1.0,
53
+ end_factor=self.min_lr_factor,
54
+ total_iters=decay_steps,
55
+ )
56
+ case "sqrt":
57
+ decay = torch.optim.lr_scheduler.LambdaLR(
58
+ optimizer,
59
+ lr_lambda=lambda step: max(
60
+ self.min_lr_factor,
61
+ (decay_steps - step) / decay_steps,
62
+ )
63
+ ** 0.5,
64
+ )
65
+ case "cosine":
66
+ decay = torch.optim.lr_scheduler.CosineAnnealingLR(
67
+ optimizer,
68
+ T_max=decay_steps,
69
+ eta_min=optimizer.param_groups[0]["lr"] * self.min_lr_factor,
70
+ )
71
+
72
+ scheduler = torch.optim.lr_scheduler.SequentialLR(
73
+ optimizer,
74
+ [warmup, stay, decay],
75
+ [self.warmup_steps, self.warmup_steps + stay_steps],
76
+ )
77
+ return scheduler
78
+
79
+
80
+ """
81
+ Static configuration dataclasses
82
+
83
+ These classes are not expected to be changed by the user
84
+ """
85
+
86
+
87
@dataclasses.dataclass
class Train:
    """Core training-loop settings (steps, precision, batching)."""

    # total number of optimizer steps to train for
    steps: int = 10_000

    # max gradient norm for clipping; None disables clipping
    grad_clip: float | None = None

    # parameter/compute dtype for training
    dtype: Literal["bfloat16", "float32"] = "float32"

    # whether to torch.compile the model
    compile: bool = False

    gc_freq: int = 50
    """Garbage collection frequency (in steps). If -1, no periodic GC is performed."""

    local_batch_size: int = 32
    """local (per device) batch size"""

    global_batch_size: int = 256
    """
    global (across all devices) batch size, used to compute
    grad_accum_steps = global_batch_size // (local_batch_size * num_devices)
    """

    use_fa4: bool = True
    """Whether to use FA4 flash attention if available."""

    val_freq: int = -1
    """Validation frequency (in steps). If -1, no validation is performed."""

    use_sac: bool = False
    """Whether to use selective activation checkpointing."""
117
+
118
+
119
@dataclasses.dataclass
class Metrics:
    """Metric logging settings."""

    # log metrics every N steps
    log_freq: int = 10
    # write metrics to TensorBoard in addition to the console
    use_tensorboard: bool = False
    # presumably: collect/log metrics from every node rather than rank 0 only
    # — TODO confirm against the trainer's usage
    all_node: bool = False
124
+
125
+
126
@dataclasses.dataclass
class Checkpoint:
    """Checkpoint saving settings (consumed by sarasa.checkpoint.Checkpointer)."""

    # save a checkpoint every N steps
    save_freq: int = 1000
    # how saving overlaps with training: "none" (sync), "default"
    # (async), or "mem_pinned" (async via pinned host memory)
    async_mode: Literal["none", "default", "mem_pinned"] = "default"
130
+
131
+
132
@dataclasses.dataclass
class Distributed:
    """Base settings for distributed training; subclasses select the strategy."""

    backend: Literal["nccl", "gloo"] = "nccl"

    init_timeout_seconds: int = 300
    """Timeout for initializing the distributed process group."""

    train_timeout_seconds: int = 100
    """Timeout for distributed training operations after the first iteration."""

    @property
    def name(self) -> str:
        # strategy identifier derived from the subclass name, e.g. "ddp"/"fsdp"
        return self.__class__.__name__.lower()
145
+
146
+
147
@dataclasses.dataclass
class DDP(Distributed):
    """DistributedDataParallel strategy; no extra options beyond the base."""

    pass
150
+
151
+
152
@dataclasses.dataclass
class FSDP(Distributed):
    """Fully Sharded Data Parallel strategy."""

    reshard_after_forward: bool = False
    """Whether to reshard model parameters after each forward pass (FSDP only)."""
156
+
157
+
158
+ @dataclasses.dataclass
159
+ class Config[ModelT, OptimizerT, LRSchedulerT, DataT]:
160
+ # variable components
161
+ model: ModelT
162
+ optim: OptimizerT
163
+ lr_scheduler: LRSchedulerT
164
+ data: DataT
165
+
166
+ # static components
167
+ train: Train = dataclasses.field(default_factory=Train)
168
+ metrics: Metrics = dataclasses.field(default_factory=Metrics)
169
+ checkpoint: Checkpoint = dataclasses.field(default_factory=Checkpoint)
170
+ distributed: DDP | FSDP = dataclasses.field(default_factory=DDP)
171
+
172
+ seed: int = 0
173
+ debug: bool = False
174
+ """ Enable debug mode with more verbose logging and checks."""
175
+
176
+ output_dir: Path | str = Path("./outputs")
177
+ """Directory to save checkpoints and logs."""
178
+
179
+ config_file: Path | str | None = None
180
+ """Path to a config file (JSON or TOML) to load configuration from."""
181
+
182
+ def __post_init__(self):
183
+ if self.output_dir is not None:
184
+ self.output_dir.mkdir(parents=True, exist_ok=True)
185
+
186
+ if hasattr(self.model, "seq_len") and self.model.seq_len is None:
187
+ if self.data.seq_len is not None:
188
+ self.model.seq_len = self.data.seq_len
189
+ else:
190
+ raise ValueError("Either model.seq_len or data.seq_len must be set.")
191
+
192
+ @classmethod
193
+ def create(
194
+ cls,
195
+ model: ModelT,
196
+ optim: OptimizerT,
197
+ lr_scheduler: LRSchedulerT,
198
+ data: DataT,
199
+ **kwargs,
200
+ ) -> Config:
201
+ return cls(
202
+ model=model,
203
+ optim=optim,
204
+ lr_scheduler=lr_scheduler,
205
+ data=data,
206
+ **kwargs,
207
+ )
208
+
209
+ @classmethod
210
+ def from_cli(
211
+ cls,
212
+ *,
213
+ model_type: ModelT = Model,
214
+ optim_type: OptimizerT = AdamW,
215
+ lr_scheduler_type: LRSchedulerT = LRScheduler,
216
+ data_type: DataT = Data,
217
+ ) -> Config:
218
+ """
219
+ initialize JobConfig from command line arguments
220
+ update the values with the following priority: CLI arguments > config file > defaults
221
+
222
+ *_type can be used to specify custom dataclass types for each section
223
+ >> config = Config.from_cli(optim_type=CustomOptimizerConfig)
224
+ """
225
+
226
+ import importlib.util
227
+
228
+ import tyro
229
+
230
+ loaded_config = None
231
+
232
+ if (under := ("--config_file" in sys.argv)) or ("--config-file" in sys.argv):
233
+ config_file = sys.argv[sys.argv.index("--config_file" if under else "--config-file") + 1]
234
+ config_file = Path(config_file)
235
+
236
+ if not config_file.exists():
237
+ raise FileNotFoundError(f"Config file {config_file} does not exist.")
238
+
239
+ if config_file.suffix != ".py":
240
+ raise ValueError("Only Python config files are supported in this method.")
241
+
242
+ spec = importlib.util.spec_from_file_location("custom_config", config_file)
243
+ module = importlib.util.module_from_spec(spec)
244
+ spec.loader.exec_module(module)
245
+ configs = [
246
+ config
247
+ for config in module.__dict__.values()
248
+ if isinstance(config, cls) and not isinstance(config, type)
249
+ ]
250
+ if len(configs) == 0:
251
+ raise ValueError(f"No Config instance found in {config_file}.")
252
+ elif len(configs) > 1:
253
+ raise ValueError(f"Multiple Config instances found in {config_file}. Please keep only one.")
254
+ else:
255
+ loaded_config = configs[0]
256
+
257
+ return tyro.cli(
258
+ cls[
259
+ model_type,
260
+ optim_type,
261
+ lr_scheduler_type,
262
+ data_type,
263
+ ],
264
+ default=loaded_config,
265
+ )
266
+
267
+
268
+ __all__ = [
269
+ "Config",
270
+ "Model",
271
+ "AdamW",
272
+ "LRScheduler",
273
+ "Data",
274
+ "Train",
275
+ "Metrics",
276
+ "Checkpoint",
277
+ "DDP",
278
+ "FSDP",
279
+ ]
@@ -0,0 +1,36 @@
1
+ import dataclasses
2
+ from pathlib import Path
3
+ from typing import Any
4
+
5
+ from torch.utils.data import DataLoader
6
+
7
+ from sarasa.data.hf_datasets import Datasets, HFTextDataset
8
+ from sarasa.data.tokenizer import HFTokenizerWrapper
9
+
10
+
11
@dataclasses.dataclass
class DataConfig:
    dataset: Datasets = Datasets.fineweb_edu_100b
    """Dataset to use for training. Can be a predefined dataset or a custom dataset path."""

    tokenizer_path: Path | str = Path("./tokenizer")
    """Path to `tokenizer.json` and `tokenizer_config.json` files."""

    seq_len: int = 2048

    num_workers: int = 4
    pin_memory: bool = True

    def create(
        self,
        batch_size: int,
    ) -> dict[str, Any]:
        """Build the tokenizer and the training DataLoader.

        Returns a dict shaped like
        ``{"tokenizer": ..., "train_loader": ...}`` (a "val_loader" entry may
        be added in the future).
        """
        tokenizer = HFTokenizerWrapper(Path(self.tokenizer_path))
        train_ds = HFTextDataset(self.dataset, "train", tokenizer, self.seq_len)
        train_loader = DataLoader(
            train_ds,
            batch_size,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
        )
        return {
            "tokenizer": tokenizer,
            "train_loader": train_loader,
        }
@@ -0,0 +1,115 @@
1
+ import enum
2
+ from typing import Any, Callable
3
+
4
+ import torch
5
+ from datasets import disable_progress_bars, load_dataset
6
+ from datasets.distributed import split_dataset_by_node
7
+ from loguru import logger
8
+ from torch.utils.data import IterableDataset
9
+
10
+ from sarasa.utils import rank, world_size
11
+
12
+
13
+ class Datasets(enum.StrEnum):
14
+ c4 = enum.auto()
15
+ fineweb_edu = enum.auto()
16
+ fineweb_edu_100b = enum.auto()
17
+ fineweb_edu_dedup = enum.auto()
18
+
19
+ def load(
20
+ self,
21
+ cache_dir: str | None,
22
+ ) -> Any:
23
+ match self:
24
+ case Datasets.c4:
25
+ return load_dataset(
26
+ "allenai/c4",
27
+ name="en",
28
+ split="train",
29
+ streaming=True,
30
+ cache_dir=cache_dir,
31
+ )
32
+ case Datasets.fineweb_edu:
33
+ return load_dataset(
34
+ "HuggingFaceFW/fineweb-edu",
35
+ name="default",
36
+ split="train",
37
+ streaming=True,
38
+ cache_dir=cache_dir,
39
+ )
40
+ case Datasets.fineweb_edu_100b:
41
+ return load_dataset(
42
+ "HuggingFaceFW/fineweb-edu",
43
+ name="sample-100BT",
44
+ split="train",
45
+ streaming=True,
46
+ cache_dir=cache_dir,
47
+ )
48
+ case Datasets.fineweb_edu_dedup:
49
+ return load_dataset(
50
+ "HuggingFaceTB/smollm-corpus",
51
+ "fineweb-edu-dedup",
52
+ split="train",
53
+ streaming=True,
54
+ cache_dir=cache_dir,
55
+ )
56
+
57
+
58
class HFTextDataset(IterableDataset):
    """Streams a HF text dataset and packs tokens into fixed-length chunks.

    Each yielded item is ``({"input": x[:-1]}, x[1:])`` where ``x`` is a
    ``seq_len + 1``-token window, so inputs and next-token labels align.
    """

    def __init__(
        self,
        dataset_name: Datasets | str,
        split: str,
        tokenizer: Callable[[str], list[int]],
        seq_len: int,
        infinite: bool = True,
        cache_dir: str | None = None,
    ):
        if rank() != 0:
            disable_progress_bars()
        self.dataset_name = dataset_name
        if dataset_name in Datasets:
            ds = Datasets(dataset_name).load(cache_dir=cache_dir)

        else:
            logger.warning(f"Unknown dataset: {dataset_name}. Trying to use `load_dataset` directly.")
            ds = load_dataset(dataset_name, split=split, streaming=True, cache_dir=cache_dir)

        # each rank iterates a disjoint shard of the stream
        self.data = split_dataset_by_node(ds, rank=rank(), world_size=world_size())
        self.tokenizer = tokenizer
        self.seq_len = seq_len
        # fix: this was never stored, so __iter__ raised AttributeError on
        # `self.infinite` once the underlying stream was exhausted
        self.infinite = infinite
        self.token_buffer: list[int] = []

    def _text_processor(
        self,
        sample: dict,
    ) -> str:
        # Default text processor: extract 'text' field
        return sample["text"]

    def __iter__(self):
        # one extra token so input (x[:-1]) and label (x[1:]) both have seq_len
        max_buffer_token_len = 1 + self.seq_len

        while True:
            for sample in iter(self.data):
                # Use the dataset-specific text processor
                sample_text = self._text_processor(sample)
                sample_tokens = self.tokenizer.encode(sample_text)
                self.token_buffer.extend(sample_tokens)

                while len(self.token_buffer) >= max_buffer_token_len:
                    x = torch.LongTensor(self.token_buffer[:max_buffer_token_len])
                    # update tokens to the remaining tokens
                    self.token_buffer = self.token_buffer[max_buffer_token_len:]
                    input = x[:-1]
                    label = x[1:]
                    yield {"input": input}, label

            if not self.infinite:
                logger.warning(f"Dataset {self.dataset_name} has run out of data")
                break
            else:
                # Reset offset for the next iteration
                logger.warning(f"Dataset {self.dataset_name} is being re-looped")
                if hasattr(self.data, "set_epoch") and hasattr(self.data, "epoch"):
                    self.data.set_epoch(self.data.epoch + 1)
@@ -0,0 +1,63 @@
1
+ import json
2
+ from pathlib import Path
3
+
4
+ from tokenizers import Tokenizer
5
+
6
+
7
class BaseTokenizerWrapper:
    """Minimal tokenizer interface expected by the data pipeline."""

    def encode(self, *args, **kwargs) -> list[int]:
        """Convert text to a list of token ids."""
        raise NotImplementedError

    def decode(self, *args, **kwargs) -> str:
        """Convert token ids back to text."""
        raise NotImplementedError

    def __len__(self) -> int:
        """Vocabulary size."""
        raise NotImplementedError
16
+
17
+
18
class HFTokenizerWrapper(BaseTokenizerWrapper):
    """Wraps a HF ``tokenizers.Tokenizer`` and guarantees a leading BOS token."""

    def __init__(
        self,
        tokenizer_path: Path,
    ):
        self.tokenizer = Tokenizer.from_file(str(tokenizer_path / "tokenizer.json"))
        with (tokenizer_path / "tokenizer_config.json").open("r") as f:
            cfg = json.load(f)

        bos_token = self._get_tokens_from_config(cfg.get("bos_token", None))
        if bos_token is None:
            raise ValueError("BOS token must be specified in the tokenizer config.")

        self.bos_token_id = self.tokenizer.token_to_id(bos_token)
        # Probe a dummy encode to learn whether the tokenizer already prepends
        # BOS by itself; only prepend manually when it does not.
        probe_ids = self.tokenizer.encode("test").ids
        self.need_bos = self.bos_token_id not in probe_ids

    def _get_tokens_from_config(
        self,
        token: dict[str, str] | str | None,
    ) -> str | None:
        """Accept both the dict form ({"content": ...}) and the plain string form."""
        return token["content"] if isinstance(token, dict) else token

    def encode(
        self,
        text: str,
    ) -> list[int]:
        """Tokenize ``text``, prepending BOS when the tokenizer does not."""
        ids = self.tokenizer.encode(text).ids
        return [self.bos_token_id, *ids] if self.need_bos else ids

    def decode(
        self,
        token_ids: list[int],
        **kwargs,
    ) -> str:
        """Convert token ids back to text."""
        return self.tokenizer.decode(token_ids, **kwargs)

    def __len__(self) -> int:
        """Vocabulary size."""
        return self.tokenizer.get_vocab_size()