torchtextclassifiers 0.0.1__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34)
  1. torchTextClassifiers/__init__.py +12 -48
  2. torchTextClassifiers/dataset/__init__.py +1 -0
  3. torchTextClassifiers/dataset/dataset.py +152 -0
  4. torchTextClassifiers/model/__init__.py +2 -0
  5. torchTextClassifiers/model/components/__init__.py +12 -0
  6. torchTextClassifiers/model/components/attention.py +126 -0
  7. torchTextClassifiers/model/components/categorical_var_net.py +128 -0
  8. torchTextClassifiers/model/components/classification_head.py +61 -0
  9. torchTextClassifiers/model/components/text_embedder.py +220 -0
  10. torchTextClassifiers/model/lightning.py +170 -0
  11. torchTextClassifiers/model/model.py +151 -0
  12. torchTextClassifiers/tokenizers/WordPiece.py +92 -0
  13. torchTextClassifiers/tokenizers/__init__.py +10 -0
  14. torchTextClassifiers/tokenizers/base.py +205 -0
  15. torchTextClassifiers/tokenizers/ngram.py +472 -0
  16. torchTextClassifiers/torchTextClassifiers.py +500 -413
  17. torchTextClassifiers/utilities/__init__.py +0 -3
  18. torchTextClassifiers/utilities/plot_explainability.py +184 -0
  19. torchtextclassifiers-1.0.0.dist-info/METADATA +87 -0
  20. torchtextclassifiers-1.0.0.dist-info/RECORD +21 -0
  21. {torchtextclassifiers-0.0.1.dist-info → torchtextclassifiers-1.0.0.dist-info}/WHEEL +1 -1
  22. torchTextClassifiers/classifiers/base.py +0 -83
  23. torchTextClassifiers/classifiers/fasttext/__init__.py +0 -25
  24. torchTextClassifiers/classifiers/fasttext/core.py +0 -269
  25. torchTextClassifiers/classifiers/fasttext/model.py +0 -752
  26. torchTextClassifiers/classifiers/fasttext/tokenizer.py +0 -346
  27. torchTextClassifiers/classifiers/fasttext/wrapper.py +0 -216
  28. torchTextClassifiers/classifiers/simple_text_classifier.py +0 -191
  29. torchTextClassifiers/factories.py +0 -34
  30. torchTextClassifiers/utilities/checkers.py +0 -108
  31. torchTextClassifiers/utilities/preprocess.py +0 -82
  32. torchTextClassifiers/utilities/utils.py +0 -346
  33. torchtextclassifiers-0.0.1.dist-info/METADATA +0 -187
  34. torchtextclassifiers-0.0.1.dist-info/RECORD +0 -17
torchTextClassifiers/model/components/text_embedder.py
@@ -0,0 +1,220 @@
+ import math
+ from dataclasses import dataclass
+ from typing import Optional
+
+ import torch
+ from torch import nn
+
+ from torchTextClassifiers.model.components.attention import AttentionConfig, Block, norm
+
+
+ @dataclass
+ class TextEmbedderConfig:
+     vocab_size: int
+     embedding_dim: int
+     padding_idx: int
+     attention_config: Optional[AttentionConfig] = None
+
+
+ class TextEmbedder(nn.Module):
+     def __init__(self, text_embedder_config: TextEmbedderConfig):
+         super().__init__()
+
+         self.config = text_embedder_config
+
+         self.attention_config = text_embedder_config.attention_config
+         if self.attention_config is not None:
+             self.attention_config.n_embd = text_embedder_config.embedding_dim
+
+         self.vocab_size = text_embedder_config.vocab_size
+         self.embedding_dim = text_embedder_config.embedding_dim
+         self.padding_idx = text_embedder_config.padding_idx
+
+         self.embedding_layer = nn.Embedding(
+             embedding_dim=self.embedding_dim,
+             num_embeddings=self.vocab_size,
+             padding_idx=self.padding_idx,
+         )
+
+         if self.attention_config is not None:
+             self.transformer = nn.ModuleDict(
+                 {
+                     "h": nn.ModuleList(
+                         [
+                             Block(self.attention_config, layer_idx)
+                             for layer_idx in range(self.attention_config.n_layers)
+                         ]
+                     ),
+                 }
+             )
+
+             head_dim = self.attention_config.n_embd // self.attention_config.n_head
+
+             if head_dim * self.attention_config.n_head != self.attention_config.n_embd:
+                 raise ValueError("embedding_dim must be divisible by n_head.")
+
+             if self.attention_config.positional_encoding:
+                 if head_dim % 2 != 0:
+                     raise ValueError(
+                         "embedding_dim / n_head must be even for rotary positional embeddings."
+                     )
+
+                 if self.attention_config.sequence_len is None:
+                     raise ValueError(
+                         "sequence_len must be specified in AttentionConfig when positional_encoding is True."
+                     )
+
+                 self.rotary_seq_len = self.attention_config.sequence_len * 10
+                 cos, sin = self._precompute_rotary_embeddings(
+                     seq_len=self.rotary_seq_len, head_dim=head_dim
+                 )
+
+                 self.register_buffer(
+                     "cos", cos, persistent=False
+                 )  # persistent=False means it's not saved to the checkpoint
+                 self.register_buffer("sin", sin, persistent=False)
+
+     def init_weights(self):
+         self.apply(self._init_weights)
+
+         # zero out c_proj weights in all blocks
+         if self.attention_config is not None:
+             for block in self.transformer.h:
+                 torch.nn.init.zeros_(block.mlp.c_proj.weight)
+                 torch.nn.init.zeros_(block.attn.c_proj.weight)
+             # init the rotary embeddings
+             head_dim = self.attention_config.n_embd // self.attention_config.n_head
+             cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
+             self.cos, self.sin = cos, sin
+         # Cast the embeddings from fp32 to bf16: optim can tolerate it and it saves memory: both in the model and the activations
+         if self.embedding_layer.weight.device.type == "cuda":
+             self.embedding_layer.to(dtype=torch.bfloat16)
+
+     def _init_weights(self, module):
+         if isinstance(module, nn.Linear):
+             # https://arxiv.org/pdf/2310.17813
+             fan_out = module.weight.size(0)
+             fan_in = module.weight.size(1)
+             std = 1.0 / math.sqrt(fan_in) * min(1.0, math.sqrt(fan_out / fan_in))
+             torch.nn.init.normal_(module.weight, mean=0.0, std=std)
+             if module.bias is not None:
+                 torch.nn.init.zeros_(module.bias)
+         elif isinstance(module, nn.Embedding):
+             torch.nn.init.normal_(module.weight, mean=0.0, std=1.0)
+
+     def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
+         """Converts input token IDs to their corresponding embeddings."""
+
+         encoded_text = input_ids  # clearer name
+         if encoded_text.dtype != torch.long:
+             encoded_text = encoded_text.to(torch.long)
+
+         batch_size, seq_len = encoded_text.shape
+         batch_size_check, seq_len_check = attention_mask.shape
+
+         if batch_size != batch_size_check or seq_len != seq_len_check:
+             raise ValueError(
+                 f"Input IDs and attention mask must have the same batch size and sequence length. "
+                 f"Got input_ids shape {encoded_text.shape} and attention_mask shape {attention_mask.shape}."
+             )
+
+         token_embeddings = self.embedding_layer(
+             encoded_text
+         )  # (batch_size, seq_len, embedding_dim)
+
+         token_embeddings = norm(token_embeddings)
+
+         if self.attention_config is not None:
+             if self.attention_config.positional_encoding:
+                 cos_sin = self.cos[:, :seq_len], self.sin[:, :seq_len]
+             else:
+                 cos_sin = None
+
+             for block in self.transformer.h:
+                 token_embeddings = block(token_embeddings, cos_sin)
+
+             token_embeddings = norm(token_embeddings)
+
+         text_embedding = self._get_sentence_embedding(
+             token_embeddings=token_embeddings, attention_mask=attention_mask
+         )
+
+         return text_embedding
+
+     def _get_sentence_embedding(
+         self, token_embeddings: torch.Tensor, attention_mask: torch.Tensor
+     ) -> torch.Tensor:
+         """
+         Compute sentence embedding from embedded tokens - "remove" second dimension.
+
+         Args (output from dataset collate_fn):
+             token_embeddings (torch.Tensor[Long]), shape (batch_size, seq_len, embedding_dim): Tokenized + padded text
+             attention_mask (torch.Tensor[Long]), shape (batch_size, seq_len): Attention mask indicating non-pad tokens
+         Returns:
+             torch.Tensor: Sentence embeddings, shape (batch_size, embedding_dim)
+         """
+
+         # average over non-pad token embeddings
+         # attention mask has 1 for non-pad tokens and 0 for pad token positions
+
+         # mask pad-tokens
+
+         if self.attention_config is not None:
+             if self.attention_config.aggregation_method is not None:
+                 if self.attention_config.aggregation_method == "first":
+                     return token_embeddings[:, 0, :]
+                 elif self.attention_config.aggregation_method == "last":
+                     lengths = attention_mask.sum(dim=1).clamp(min=1)  # last non-pad token index + 1
+                     return token_embeddings[
+                         torch.arange(token_embeddings.size(0)),
+                         lengths - 1,
+                         :,
+                     ]
+                 else:
+                     if self.attention_config.aggregation_method != "mean":
+                         raise ValueError(
+                             f"Unknown aggregation method: {self.attention_config.aggregation_method}. Supported methods are 'mean', 'first', 'last'."
+                         )
+
+         assert self.attention_config is None or self.attention_config.aggregation_method == "mean"
+
+         mask = attention_mask.unsqueeze(-1).float()  # (batch_size, seq_len, 1)
+         masked_embeddings = token_embeddings * mask  # (batch_size, seq_len, embedding_dim)
+
+         sentence_embedding = masked_embeddings.sum(dim=1) / mask.sum(dim=1).clamp(
+             min=1.0
+         )  # avoid division by zero
+
+         sentence_embedding = torch.nan_to_num(sentence_embedding, 0.0)
+
+         return sentence_embedding
+
+     def __call__(self, *args, **kwargs):
+         out = super().__call__(*args, **kwargs)
+         if out.dim() != 2:
+             raise ValueError(
+                 f"Output of {self.__class__.__name__}.forward must be 2D "
+                 f"(got shape {tuple(out.shape)})"
+             )
+         return out
+
+     def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None):
+         # autodetect the device from model embeddings
+         if device is None:
+             device = next(self.parameters()).device
+
+         # stride the channels
+         channel_range = torch.arange(0, head_dim, 2, dtype=torch.float32, device=device)
+         inv_freq = 1.0 / (base ** (channel_range / head_dim))
+         # stride the time steps
+         t = torch.arange(seq_len, dtype=torch.float32, device=device)
+         # calculate the rotation frequencies at each (time, channel) pair
+         freqs = torch.outer(t, inv_freq)
+         cos, sin = freqs.cos(), freqs.sin()
+         cos, sin = cos.bfloat16(), sin.bfloat16()  # keep them in bfloat16
+         cos, sin = (
+             cos[None, :, None, :],
+             sin[None, :, None, :],
+         )  # add batch and head dims for later broadcasting
+
+         return cos, sin
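
For orientation (not part of the published diff): a minimal usage sketch of the new TextEmbedder, assuming the 1.0.0 wheel is installed. The vocabulary size, embedding dimension, and tensor shapes are illustrative, and attention_config is left at None, so only the embedding layer and mean pooling shown above are exercised.

import torch

from torchTextClassifiers.model.components.text_embedder import TextEmbedder, TextEmbedderConfig

# Fields mirror the TextEmbedderConfig dataclass above; no attention blocks are attached.
config = TextEmbedderConfig(vocab_size=5000, embedding_dim=64, padding_idx=0)
embedder = TextEmbedder(config)
embedder.init_weights()

input_ids = torch.randint(1, 5000, (2, 12))  # (batch_size, seq_len), already tokenized
attention_mask = torch.ones_like(input_ids)  # 1 for real tokens, 0 for padding
sentence_embeddings = embedder(input_ids, attention_mask)  # (2, 64), mean over non-pad tokens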
torchTextClassifiers/model/lightning.py
@@ -0,0 +1,170 @@
+ import pytorch_lightning as pl
+ import torch
+ from torchmetrics import Accuracy
+
+ from .model import TextClassificationModel
+
+ # ============================================================================
+ # PyTorch Lightning Module
+ # ============================================================================
+
+
+ class TextClassificationModule(pl.LightningModule):
+     """Pytorch Lightning Module for FastTextModel."""
+
+     def __init__(
+         self,
+         model: TextClassificationModel,
+         loss,
+         optimizer,
+         optimizer_params,
+         scheduler,
+         scheduler_params,
+         scheduler_interval="epoch",
+         **kwargs,
+     ):
+         """
+         Initialize FastTextModule.
+
+         Args:
+             model: Model.
+             loss: Loss
+             optimizer: Optimizer
+             optimizer_params: Optimizer parameters.
+             scheduler: Scheduler.
+             scheduler_params: Scheduler parameters.
+             scheduler_interval: Scheduler interval.
+         """
+         super().__init__()
+         self.save_hyperparameters(ignore=["model", "loss"])
+
+         self.model = model
+         self.loss = loss
+         self.accuracy_fn = Accuracy(task="multiclass", num_classes=self.model.num_classes)
+         self.optimizer = optimizer
+         self.optimizer_params = optimizer_params
+         self.scheduler = scheduler
+         self.scheduler_params = scheduler_params
+         self.scheduler_interval = scheduler_interval
+
+     def forward(self, batch) -> torch.Tensor:
+         """
+         Perform forward-pass.
+
+         Args:
+             batch (List[torch.LongTensor]): Batch to perform forward-pass on.
+
+         Returns (torch.Tensor): Prediction.
+         """
+         return self.model(
+             input_ids=batch["input_ids"],
+             attention_mask=batch["attention_mask"],
+             categorical_vars=batch.get("categorical_vars", None),
+         )
+
+     def training_step(self, batch, batch_idx: int) -> torch.Tensor:
+         """
+         Training step.
+
+         Args:
+             batch (List[torch.LongTensor]): Training batch.
+             batch_idx (int): Batch index.
+
+         Returns (torch.Tensor): Loss tensor.
+         """
+
+         targets = batch["labels"]
+
+         outputs = self.forward(batch)
+
+         if isinstance(self.loss, torch.nn.BCEWithLogitsLoss):
+             targets = targets.float()
+
+         loss = self.loss(outputs, targets)
+         self.log("train_loss", loss, on_epoch=True, on_step=True, prog_bar=True)
+         accuracy = self.accuracy_fn(outputs, targets)
+         self.log("train_accuracy", accuracy, on_epoch=True, on_step=False, prog_bar=True)
+
+         torch.cuda.empty_cache()
+
+         return loss
+
+     def validation_step(self, batch, batch_idx: int):
+         """
+         Validation step.
+
+         Args:
+             batch (List[torch.LongTensor]): Validation batch.
+             batch_idx (int): Batch index.
+
+         Returns (torch.Tensor): Loss tensor.
+         """
+         targets = batch["labels"]
+
+         outputs = self.forward(batch)
+         loss = self.loss(outputs, targets)
+         self.log("val_loss", loss, on_epoch=True, on_step=False, prog_bar=True, sync_dist=True)
+
+         accuracy = self.accuracy_fn(outputs, targets)
+         self.log("val_accuracy", accuracy, on_epoch=True, on_step=False, prog_bar=True)
+         return loss
+
+     def test_step(self, batch, batch_idx: int):
+         """
+         Test step.
+
+         Args:
+             batch (List[torch.LongTensor]): Test batch.
+             batch_idx (int): Batch index.
+
+         Returns (torch.Tensor): Loss tensor.
+         """
+         targets = batch["labels"]
+
+         outputs = self.forward(batch)
+         loss = self.loss(outputs, targets)
+
+         accuracy = self.accuracy_fn(outputs, targets)
+
+         return loss, accuracy
+
+     def predict_step(self, batch, batch_idx: int = 0, dataloader_idx: int = 0):
+         """
+         Prediction step.
+
+         Args:
+             batch (List[torch.LongTensor]): Prediction batch.
+             batch_idx (int): Batch index.
+             dataloader_idx (int): Dataloader index.
+
+         Returns (torch.Tensor): Predictions.
+         """
+         outputs = self.forward(batch)
+         return outputs
+
+     def configure_optimizers(self):
+         """
+         Configure optimizer for Pytorch lighting.
+
+         Returns: Optimizer and scheduler for pytorch lighting.
+         """
+         optimizer = self.optimizer(self.parameters(), **self.optimizer_params)
+
+         if self.scheduler is None:
+             return optimizer
+
+         # Only use scheduler if it's not ReduceLROnPlateau or if we can ensure val_loss is available
+         # For complex training setups, sometimes val_loss is not available every epoch
+         if hasattr(self.scheduler, "__name__") and "ReduceLROnPlateau" in self.scheduler.__name__:
+             # For ReduceLROnPlateau, use train_loss as it's always available
+             scheduler = self.scheduler(optimizer, **self.scheduler_params)
+             scheduler_config = {
+                 "scheduler": scheduler,
+                 "monitor": "train_loss",
+                 "interval": self.scheduler_interval,
+             }
+             return [optimizer], [scheduler_config]
+         else:
+             # For other schedulers (StepLR, etc.), no monitoring needed
+             scheduler = self.scheduler(optimizer, **self.scheduler_params)
+             return [optimizer], [scheduler]
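
For orientation (not part of the published diff): a sketch of how TextClassificationModule appears meant to be wired, assuming `model` is an already-built TextClassificationModel and that `train_loader`/`val_loader` yield dicts with input_ids, attention_mask, labels (and optionally categorical_vars), matching the step methods above. Optimizer and scheduler are passed as classes, since configure_optimizers instantiates them itself.

import pytorch_lightning as pl
import torch

from torchTextClassifiers.model.lightning import TextClassificationModule

module = TextClassificationModule(
    model=model,  # a TextClassificationModel constructed elsewhere (assumption)
    loss=torch.nn.CrossEntropyLoss(),
    optimizer=torch.optim.AdamW,  # class, not instance
    optimizer_params={"lr": 1e-3},
    scheduler=torch.optim.lr_scheduler.ReduceLROnPlateau,  # monitored on train_loss per configure_optimizers
    scheduler_params={"mode": "min", "patience": 2},
    scheduler_interval="epoch",
)

trainer = pl.Trainer(max_epochs=5)
trainer.fit(module, train_dataloaders=train_loader, val_dataloaders=val_loader)  # loaders assumed defined elsewhere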
torchTextClassifiers/model/model.py
@@ -0,0 +1,151 @@
+ """FastText model components.
+
+ This module contains the PyTorch model, Lightning module, and dataset classes
+ for FastText classification. Consolidates what was previously in pytorch_model.py,
+ lightning_module.py, and dataset.py.
+ """
+
+ import logging
+ from typing import Annotated, Optional
+
+ import torch
+ from torch import nn
+
+ from torchTextClassifiers.model.components import (
+     CategoricalForwardType,
+     CategoricalVariableNet,
+     ClassificationHead,
+     TextEmbedder,
+ )
+
+ logger = logging.getLogger(__name__)
+
+ logging.basicConfig(
+     level=logging.INFO,
+     format="%(asctime)s - %(name)s - %(message)s",
+     datefmt="%Y-%m-%d %H:%M:%S",
+     handlers=[logging.StreamHandler()],
+ )
+
+
+ # ============================================================================
+ # PyTorch Model
+
+ # It takes PyTorch tensors as input (not raw text!),
+ # and it outputs raw not-softmaxed logits, not predictions
+ # ============================================================================
+
+
+ class TextClassificationModel(nn.Module):
+     """FastText Pytorch Model."""
+
+     def __init__(
+         self,
+         classification_head: ClassificationHead,
+         text_embedder: Optional[TextEmbedder] = None,
+         categorical_variable_net: Optional[CategoricalVariableNet] = None,
+     ):
+         """
+         Constructor for the FastTextModel class.
+
+         Args:
+             classification_head (ClassificationHead): The classification head module.
+             text_embedder (Optional[TextEmbedder]): The text embedding module.
+                 If not provided, assumes that input text is already embedded (as tensors) and directly passed to the classification head.
+             categorical_variable_net (Optional[CategoricalVariableNet]): The categorical variable network module.
+                 If not provided, assumes no categorical variables are used.
+         """
+         super().__init__()
+
+         self.text_embedder = text_embedder
+
+         self.categorical_variable_net = categorical_variable_net
+         if not self.categorical_variable_net:
+             logger.info("🔹 No categorical variable network provided; using only text embeddings.")
+
+         self.classification_head = classification_head
+
+         self._validate_component_connections()
+
+         self.num_classes = self.classification_head.num_classes
+
+         torch.nn.init.zeros_(self.classification_head.net.weight)
+         if self.text_embedder is not None:
+             self.text_embedder.init_weights()
+
+     def _validate_component_connections(self):
+         def _check_text_categorical_connection(self, text_embedder, cat_var_net):
+             if cat_var_net.forward_type == CategoricalForwardType.SUM_TO_TEXT:
+                 if text_embedder.embedding_dim != cat_var_net.output_dim:
+                     raise ValueError(
+                         "Text embedding dimension must match categorical variable embedding dimension."
+                     )
+                 self.expected_classification_head_input_dim = text_embedder.embedding_dim
+             else:
+                 self.expected_classification_head_input_dim = (
+                     text_embedder.embedding_dim + cat_var_net.output_dim
+                 )
+
+         if self.text_embedder:
+             if self.categorical_variable_net:
+                 _check_text_categorical_connection(
+                     self, self.text_embedder, self.categorical_variable_net
+                 )
+             else:
+                 self.expected_classification_head_input_dim = self.text_embedder.embedding_dim
+
+             if self.expected_classification_head_input_dim != self.classification_head.input_dim:
+                 raise ValueError(
+                     "Classification head input dimension does not match expected dimension from text embedder and categorical variable net."
+                 )
+         else:
+             logger.warning(
+                 "⚠️ No text embedder provided; assuming input text is already embedded or vectorized. Take care that the classification head input dimension matches the input text dimension."
+             )
+
+     def forward(
+         self,
+         input_ids: Annotated[torch.Tensor, "batch seq_len"],
+         attention_mask: Annotated[torch.Tensor, "batch seq_len"],
+         categorical_vars: Annotated[torch.Tensor, "batch num_cats"],
+         **kwargs,
+     ) -> torch.Tensor:
+         """
+         Memory-efficient forward pass implementation.
+
+         Args: output from dataset collate_fn
+             input_ids (torch.Tensor[Long]), shape (batch_size, seq_len): Tokenized + padded text
+             attention_mask (torch.Tensor[int]), shape (batch_size, seq_len): Attention mask indicating non-pad tokens
+             categorical_vars (torch.Tensor[Long]): Additional categorical features, (batch_size, num_categorical_features)
+
+         Returns:
+             torch.Tensor: Model output scores for each class - shape (batch_size, num_classes)
+                 Raw, not softmaxed.
+         """
+         encoded_text = input_ids  # clearer name
+         if self.text_embedder is None:
+             x_text = encoded_text.float()
+         else:
+             x_text = self.text_embedder(input_ids=encoded_text, attention_mask=attention_mask)
+
+         if self.categorical_variable_net:
+             x_cat = self.categorical_variable_net(categorical_vars)
+
+             if (
+                 self.categorical_variable_net.forward_type
+                 == CategoricalForwardType.AVERAGE_AND_CONCAT
+                 or self.categorical_variable_net.forward_type
+                 == CategoricalForwardType.CONCATENATE_ALL
+             ):
+                 x_combined = torch.cat((x_text, x_cat), dim=1)
+             else:
+                 assert (
+                     self.categorical_variable_net.forward_type == CategoricalForwardType.SUM_TO_TEXT
+                 )
+                 x_combined = x_text + x_cat
+         else:
+             x_combined = x_text
+
+         logits = self.classification_head(x_combined)
+
+         return logits
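
For orientation (not part of the published diff): a sketch of how the components compose, assuming `classification_head` is a ClassificationHead built elsewhere (its constructor lives in classification_head.py and is not shown in this hunk) with input_dim equal to the embedder's embedding_dim and num_classes set; _validate_component_connections above enforces exactly that.

import torch

from torchTextClassifiers.model.components.text_embedder import TextEmbedder, TextEmbedderConfig
from torchTextClassifiers.model.model import TextClassificationModel

embedder = TextEmbedder(TextEmbedderConfig(vocab_size=5000, embedding_dim=64, padding_idx=0))

# `classification_head` is assumed to have input_dim=64; no categorical variable net is used.
model = TextClassificationModel(classification_head=classification_head, text_embedder=embedder)

input_ids = torch.randint(1, 5000, (8, 20))
attention_mask = (input_ids != 0).long()
logits = model(input_ids=input_ids, attention_mask=attention_mask, categorical_vars=None)  # raw logits, (8, num_classes)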
torchTextClassifiers/tokenizers/WordPiece.py
@@ -0,0 +1,92 @@
+ import logging
+ import os
+ from typing import List, Optional
+
+ from torchTextClassifiers.tokenizers import HAS_HF, HuggingFaceTokenizer
+
+ if not HAS_HF:
+     raise ImportError(
+         "The HuggingFace dependencies are needed to use this tokenizer. Please run 'uv add torchTextClassifiers --extra huggingface."
+     )
+ else:
+     from tokenizers import (
+         Tokenizer,
+         decoders,
+         models,
+         normalizers,
+         pre_tokenizers,
+         processors,
+         trainers,
+     )
+     from transformers import PreTrainedTokenizerFast
+
+ logger = logging.getLogger(__name__)
+
+
+ class WordPieceTokenizer(HuggingFaceTokenizer):
+     def __init__(self, vocab_size: int, trained: bool = False, output_dim: Optional[int] = None):
+         """Largely inspired by https://huggingface.co/learn/llm-course/chapter6/8"""
+
+         super().__init__(vocab_size=vocab_size, output_dim=output_dim)
+
+         self.unk_token = "[UNK]"
+         self.pad_token = "[PAD]"
+         self.cls_token = "[CLS]"
+         self.sep_token = "[SEP]"
+         self.special_tokens = [
+             self.unk_token,
+             self.pad_token,
+             self.cls_token,
+             self.sep_token,
+         ]
+         self.vocab_size = vocab_size
+         self.context_size = output_dim
+
+         self.tokenizer = Tokenizer(models.WordPiece(unk_token=self.unk_token))
+
+         self.tokenizer.normalizer = normalizers.BertNormalizer(
+             lowercase=True
+         )  # NFD, lowercase, strip accents - BERT style
+
+         self.tokenizer.pre_tokenizer = (
+             pre_tokenizers.BertPreTokenizer()
+         )  # split on whitespace and punctuation - BERT style
+         self.trained = trained
+
+     def _post_training(self):
+         if not self.trained:
+             raise RuntimeError(
+                 "Tokenizer must be trained before applying post-training configurations."
+             )
+
+         self.tokenizer.post_processor = processors.BertProcessing(
+             (self.cls_token, self.tokenizer.token_to_id(self.cls_token)),
+             (self.sep_token, self.tokenizer.token_to_id(self.sep_token)),
+         )
+         self.tokenizer.decoder = decoders.WordPiece(prefix="##")
+         self.padding_idx = self.tokenizer.token_to_id("[PAD]")
+         self.tokenizer.enable_padding(pad_id=self.padding_idx, pad_token="[PAD]")
+
+         self.tokenizer = PreTrainedTokenizerFast(tokenizer_object=self.tokenizer)
+         self.vocab_size = len(self.tokenizer)
+
+     def train(
+         self, training_corpus: List[str], save_path: str = None, filesystem=None, s3_save_path=None
+     ):
+         trainer = trainers.WordPieceTrainer(
+             vocab_size=self.vocab_size,
+             special_tokens=self.special_tokens,
+         )
+         self.tokenizer.train_from_iterator(training_corpus, trainer=trainer)
+         self.trained = True
+         self._post_training()
+
+         if save_path:
+             self.tokenizer.save(save_path)
+             logger.info(f"💾 Tokenizer saved at {save_path}")
+             if filesystem and s3_save_path:
+                 parent_dir = os.path.dirname(save_path)
+                 if not filesystem.exists(parent_dir):
+                     filesystem.mkdirs(parent_dir)
+                 filesystem.put(save_path, s3_save_path)
+                 logger.info(f"💾 Tokenizer uploaded to S3 at {s3_save_path}")
torchTextClassifiers/tokenizers/__init__.py
@@ -0,0 +1,10 @@
+ from .base import (
+     HAS_HF as HAS_HF,
+ )
+ from .base import BaseTokenizer as BaseTokenizer
+ from .base import (
+     HuggingFaceTokenizer as HuggingFaceTokenizer,
+ )
+ from .base import TokenizerOutput as TokenizerOutput
+ from .ngram import NGramTokenizer as NGramTokenizer
+ from .WordPiece import WordPieceTokenizer as WordPieceTokenizer