torchtextclassifiers 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torchTextClassifiers/__init__.py +12 -48
- torchTextClassifiers/dataset/__init__.py +1 -0
- torchTextClassifiers/dataset/dataset.py +114 -0
- torchTextClassifiers/model/__init__.py +2 -0
- torchTextClassifiers/model/components/__init__.py +12 -0
- torchTextClassifiers/model/components/attention.py +126 -0
- torchTextClassifiers/model/components/categorical_var_net.py +128 -0
- torchTextClassifiers/model/components/classification_head.py +43 -0
- torchTextClassifiers/model/components/text_embedder.py +220 -0
- torchTextClassifiers/model/lightning.py +166 -0
- torchTextClassifiers/model/model.py +151 -0
- torchTextClassifiers/tokenizers/WordPiece.py +92 -0
- torchTextClassifiers/tokenizers/__init__.py +10 -0
- torchTextClassifiers/tokenizers/base.py +205 -0
- torchTextClassifiers/tokenizers/ngram.py +472 -0
- torchTextClassifiers/torchTextClassifiers.py +463 -405
- torchTextClassifiers/utilities/__init__.py +0 -3
- torchTextClassifiers/utilities/plot_explainability.py +184 -0
- torchtextclassifiers-0.1.0.dist-info/METADATA +73 -0
- torchtextclassifiers-0.1.0.dist-info/RECORD +21 -0
- {torchtextclassifiers-0.0.1.dist-info → torchtextclassifiers-0.1.0.dist-info}/WHEEL +1 -1
- torchTextClassifiers/classifiers/base.py +0 -83
- torchTextClassifiers/classifiers/fasttext/__init__.py +0 -25
- torchTextClassifiers/classifiers/fasttext/core.py +0 -269
- torchTextClassifiers/classifiers/fasttext/model.py +0 -752
- torchTextClassifiers/classifiers/fasttext/tokenizer.py +0 -346
- torchTextClassifiers/classifiers/fasttext/wrapper.py +0 -216
- torchTextClassifiers/classifiers/simple_text_classifier.py +0 -191
- torchTextClassifiers/factories.py +0 -34
- torchTextClassifiers/utilities/checkers.py +0 -108
- torchTextClassifiers/utilities/preprocess.py +0 -82
- torchTextClassifiers/utilities/utils.py +0 -346
- torchtextclassifiers-0.0.1.dist-info/METADATA +0 -187
- torchtextclassifiers-0.0.1.dist-info/RECORD +0 -17

torchTextClassifiers/tokenizers/base.py
@@ -0,0 +1,205 @@
+from abc import ABC, abstractmethod
+from dataclasses import asdict, dataclass
+from typing import Any, Dict, List, Optional, Union
+
+import numpy as np
+import torch
+
+try:
+    from tokenizers import Tokenizer
+    from transformers import AutoTokenizer, PreTrainedTokenizerFast
+
+    HAS_HF = True
+except ImportError:
+    HAS_HF = False
+
+
+@dataclass
+class TokenizerOutput:
+    input_ids: torch.Tensor  # shape: (batch_size, seq_len)
+    attention_mask: torch.Tensor  # shape: (batch_size, seq_len)
+    offset_mapping: Optional[torch.Tensor] = None  # shape: (batch_size, seq_len, 2)
+    word_ids: Optional[np.ndarray] = None  # shape: (batch_size, seq_len)
+
+    def to_dict(self) -> Dict[str, Any]:
+        return asdict(self)
+
+    @classmethod
+    def from_dict(cls, data: Dict[str, Any]) -> "TokenizerOutput":
+        return cls(**data)
+
+    def __post_init__(self):
+        # --- Basic type checks ---
+        if not isinstance(self.input_ids, torch.Tensor):
+            raise TypeError(f"input_ids must be a torch.Tensor, got {type(self.input_ids)}")
+        if not isinstance(self.attention_mask, torch.Tensor):
+            raise TypeError(
+                f"attention_mask must be a torch.Tensor, got {type(self.attention_mask)}"
+            )
+        if self.offset_mapping is not None and not isinstance(self.offset_mapping, torch.Tensor):
+            raise TypeError(
+                f"offset_mapping must be a torch.Tensor or None, got {type(self.offset_mapping)}"
+            )
+        if self.word_ids is not None and not isinstance(self.word_ids, np.ndarray):
+            raise TypeError(f"word_ids must be a numpy.ndarray or None, got {type(self.word_ids)}")
+
+        # --- Shape consistency checks ---
+        if self.input_ids.shape != self.attention_mask.shape:
+            raise ValueError(
+                f"Shape mismatch: input_ids {self.input_ids.shape} and attention_mask {self.attention_mask.shape}"
+            )
+
+        if self.offset_mapping is not None:
+            expected_shape = (*self.input_ids.shape, 2)
+            if self.offset_mapping.shape != expected_shape:
+                raise ValueError(
+                    f"offset_mapping should have shape {expected_shape}, got {self.offset_mapping.shape}"
+                )
+
+        if self.word_ids is not None:
+            if self.word_ids.shape != self.input_ids.shape:
+                raise ValueError(
+                    f"word_ids should have shape {self.input_ids.shape}, got {self.word_ids.shape}"
+                )
+
+
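By way of illustration, a minimal usage sketch of the new `TokenizerOutput` container (the import path follows the one used in `ngram.py` below; the tensor values are made up):

```python
import numpy as np
import torch

from torchTextClassifiers.tokenizers import TokenizerOutput

out = TokenizerOutput(
    input_ids=torch.tensor([[4, 7, 2, 0]]),       # one sequence, padded to length 4
    attention_mask=torch.tensor([[1, 1, 1, 0]]),  # padding position masked out
    word_ids=np.array([[0, 1, None, None]], dtype=object),
)
print(out.to_dict()["input_ids"].shape)  # torch.Size([1, 4])

# __post_init__ enforces the shape contract: mismatched shapes raise ValueError.
try:
    TokenizerOutput(
        input_ids=torch.tensor([[4, 7, 2, 0]]),
        attention_mask=torch.tensor([[1, 1, 1]]),
    )
except ValueError as err:
    print(err)  # Shape mismatch: input_ids torch.Size([1, 4]) and attention_mask torch.Size([1, 3])
```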
+class BaseTokenizer(ABC):
+    def __init__(
+        self,
+        vocab_size: int,
+        padding_idx: int,
+        output_vectorized: bool = False,
+        output_dim: Optional[int] = None,
+    ):
+        """
+        Base class for tokenizers.
+
+        Args:
+            vocab_size (int): Size of the vocabulary.
+            padding_idx (int): Index of the padding token.
+            output_vectorized (bool): Whether the tokenizer outputs vectorized tokens
+                (True, for instance, for a TF-IDF tokenizer).
+            output_dim (Optional[int]): Output dimension; required when output_vectorized is True.
+        """
+
+        self.vocab_size = vocab_size
+        self.output_vectorized = output_vectorized
+        self.output_dim = output_dim
+        self.padding_idx = padding_idx
+        if self.output_vectorized:
+            if output_dim is None:
+                raise ValueError(
+                    "Tokenizer's output_dim must be provided if output_vectorized is True."
+                )
+
+    @abstractmethod
+    def tokenize(self, text: Union[str, List[str]]) -> TokenizerOutput:
+        """Tokenizes the raw input text and returns a TokenizerOutput."""
+        pass
+
+    def __len__(self):
+        return self.vocab_size
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}(vocab_size={self.vocab_size}, output_vectorized={self.output_vectorized}, output_dim={self.output_dim})"
+
+    def __call__(self, text: Union[str, List[str]], **kwargs) -> TokenizerOutput:
+        return self.tokenize(text, **kwargs)
+
+
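For context, a hypothetical `BaseTokenizer` subclass (not part of the package) showing what the abstract interface expects from an implementer:

```python
# Hypothetical example: a toy whitespace tokenizer built on the BaseTokenizer contract.
from typing import List, Union

import torch

from torchTextClassifiers.tokenizers import BaseTokenizer, TokenizerOutput


class WhitespaceTokenizer(BaseTokenizer):
    """Toy tokenizer: one id per whitespace-separated word, from a fixed vocabulary."""

    def __init__(self, vocab: List[str]):
        # id 0 is reserved for padding, id 1 for unknown words
        self.word_to_id = {w: i + 2 for i, w in enumerate(vocab)}
        super().__init__(vocab_size=len(vocab) + 2, padding_idx=0)

    def tokenize(self, text: Union[str, List[str]]) -> TokenizerOutput:
        texts = [text] if isinstance(text, str) else text
        ids = [[self.word_to_id.get(w, 1) for w in t.split()] for t in texts]
        seq_len = max(len(s) for s in ids)
        input_ids = torch.full((len(ids), seq_len), self.padding_idx, dtype=torch.long)
        attention_mask = torch.zeros((len(ids), seq_len), dtype=torch.long)
        for i, s in enumerate(ids):
            input_ids[i, : len(s)] = torch.tensor(s, dtype=torch.long)
            attention_mask[i, : len(s)] = 1
        return TokenizerOutput(input_ids=input_ids, attention_mask=attention_mask)


tok = WhitespaceTokenizer(["hello", "world"])
print(tok("hello world unknown").input_ids)  # tensor([[2, 3, 1]])
```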
+class HuggingFaceTokenizer(BaseTokenizer):
+    def __init__(
+        self,
+        vocab_size: int,
+        output_dim: Optional[int] = None,
+        padding_idx: Optional[int] = None,
+        trained: bool = False,
+    ):
+        super().__init__(
+            vocab_size, output_vectorized=False, output_dim=output_dim, padding_idx=padding_idx
+        )  # it outputs token ids, not vectors
+
+        self.trained = trained
+        self.tokenizer = None
+        self.padding_idx = padding_idx
+        self.output_dim = output_dim  # constant context size for the whole batch
+
+    def tokenize(
+        self,
+        text: Union[str, List[str]],
+        return_offsets_mapping: Optional[bool] = False,
+        return_word_ids: Optional[bool] = False,
+    ) -> TokenizerOutput:
+        if not self.trained:
+            raise RuntimeError("Tokenizer must be trained before tokenization.")
+
+        # Pad to the longest sequence if no output_dim is specified
+        padding = True if self.output_dim is None else "max_length"
+        truncation = True if self.output_dim is not None else False
+
+        tokenize_output = self.tokenizer(
+            text,
+            padding=padding,
+            return_tensors="pt",
+            truncation=truncation,
+            max_length=self.output_dim,
+            return_offsets_mapping=return_offsets_mapping,
+        )  # method from PreTrainedTokenizerFast
+
+        encoded_text = tokenize_output["input_ids"]
+
+        if return_word_ids:
+            word_ids = np.array([tokenize_output.word_ids(i) for i in range(len(encoded_text))])
+        else:
+            word_ids = None
+
+        return TokenizerOutput(
+            input_ids=encoded_text,
+            attention_mask=tokenize_output["attention_mask"],
+            offset_mapping=tokenize_output.get("offset_mapping", None),
+            word_ids=word_ids,
+        )
+
+    @classmethod
+    def load_from_pretrained(cls, tokenizer_name: str, output_dim: Optional[int] = None):
+        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
+        padding_idx = tokenizer.pad_token_id
+        instance = cls(
+            vocab_size=len(tokenizer), trained=True, padding_idx=padding_idx, output_dim=output_dim
+        )
+        instance.tokenizer = tokenizer
+        return instance
+
+    @classmethod
+    def load(cls, load_path: str):
+        loaded_tokenizer = PreTrainedTokenizerFast(tokenizer_file=load_path)
+        instance = cls(vocab_size=len(loaded_tokenizer), trained=True)
+        instance.tokenizer = loaded_tokenizer
+        # instance._post_training()
+        return instance
+
+    @classmethod
+    def load_from_s3(cls, s3_path: str, filesystem):
+        if not filesystem.exists(s3_path):
+            raise FileNotFoundError(
+                f"Tokenizer not found at {s3_path}. Please train it first (see src/train_tokenizers)."
+            )
+
+        with filesystem.open(s3_path, "rb") as f:
+            json_str = f.read().decode("utf-8")
+
+        tokenizer_obj = Tokenizer.from_str(json_str)
+        tokenizer = PreTrainedTokenizerFast(tokenizer_object=tokenizer_obj)
+        instance = cls(vocab_size=len(tokenizer), trained=True)
+        instance.tokenizer = tokenizer
+        instance._post_training()
+        return instance
+
+    def train(self, *args, **kwargs):
+        raise NotImplementedError(
+            "This tokenizer cannot be trained directly. "
+            "Load it from pretrained or implement train() in a subclass."
+        )
+
+    def _post_training(self):
+        raise NotImplementedError("_post_training() not implemented for HuggingFaceTokenizer.")
+
+    def __repr__(self):
+        return f"{self.__class__.__name__} \n HuggingFace tokenizer: {self.tokenizer.__repr__()}"
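A usage sketch for the Hugging Face wrapper; the checkpoint name is illustrative, and `transformers`/`tokenizers` must be installed for `HAS_HF` to be True:

```python
# Illustrative only: any fast Hugging Face tokenizer name works here.
from torchTextClassifiers.tokenizers.base import HuggingFaceTokenizer

tok = HuggingFaceTokenizer.load_from_pretrained("bert-base-uncased", output_dim=16)
out = tok(["hello world", "a longer second sentence"], return_word_ids=True)

print(out.input_ids.shape)        # torch.Size([2, 16]): padded/truncated to output_dim
print(out.attention_mask.sum(1))  # number of real tokens per sequence
print(out.word_ids[0])            # word index of each sub-token (None for special/pad tokens)
```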

torchTextClassifiers/tokenizers/ngram.py
@@ -0,0 +1,472 @@
+import json
+import re
+import unicodedata
+from functools import lru_cache
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+from torchTextClassifiers.tokenizers import BaseTokenizer, TokenizerOutput
+
+# ============================================================================
+# Optimized normalization
+# ============================================================================
+
+_fasttext_non_alnum = re.compile(r"[^a-z0-9]+")
+_fasttext_multi_space = re.compile(r"\s+")
+
+# Pre-compile translation table for faster character removal
+_COMBINING_MARKS = {c: None for c in range(0x0300, 0x0370)}
+
+
+@lru_cache(maxsize=10000)
+def _clean_single_text_cached(text: str) -> str:
+    """Cached version of text cleaning - major speedup for repeated texts."""
+    t = text.lower()
+    t = unicodedata.normalize("NFKD", t)
+    # Faster: use translate() instead of a list comprehension
+    t = t.translate(_COMBINING_MARKS)
+    t = _fasttext_non_alnum.sub(" ", t)
+    t = _fasttext_multi_space.sub(" ", t)
+    return t.strip()
+
+
+def clean_text_feature(texts: List[str]) -> List[str]:
+    """Vectorized text cleaning with caching."""
+    return [_clean_single_text_cached(t) for t in texts]
+
+
+# ============================================================================
+# Optimized hash function
+# ============================================================================
+
+
+def fast_hash(s: str) -> int:
+    """FNV-1a hash - simple and fast."""
+    h = 2166136261
+    for c in s:
+        h ^= ord(c)
+        h = (h * 16777619) & 0xFFFFFFFF
+    return h
+
+
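A quick sanity check of the two helpers above (module path taken from the file list; expected outputs shown in comments):

```python
from torchTextClassifiers.tokenizers.ngram import clean_text_feature, fast_hash

# Lowercased, accents stripped via NFKD + combining-mark removal, punctuation collapsed.
print(clean_text_feature(["Héllo,   World!"]))  # ['hello world']

# FNV-1a is deterministic, so the same character n-gram always lands in the same bucket.
num_buckets = 2_000_000
print(fast_hash("<hel") % num_buckets == fast_hash("<hel") % num_buckets)  # True
```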
+# ============================================================================
+# Pre-computed subword cache
+# ============================================================================
+
+
+class SubwordCache:
+    """Aggressive pre-computation cache for subwords."""
+
+    def __init__(
+        self,
+        word_to_id: dict,
+        min_n: int,
+        max_n: int,
+        num_tokens: int,
+        nwords: int,
+        unk_token_id: int,
+    ):
+        self.cache = {}
+        self.word_to_id = word_to_id
+        self.min_n = min_n
+        self.max_n = max_n
+        self.num_tokens = num_tokens
+        self.nwords = nwords
+        self.unk_token_id = unk_token_id
+
+        # Pre-compute for all vocabulary words
+        self._precompute_vocab()
+
+    def _precompute_vocab(self):
+        """Pre-compute subwords for the entire vocabulary."""
+        for word, word_id in self.word_to_id.items():
+            self.cache[word] = self._compute_subwords(word, word_id)
+
+    def _compute_subwords(self, word: str, word_id: Optional[int] = None) -> List[int]:
+        """Compute subword indices for a word."""
+        indices = []
+
+        # Add word token if in vocab
+        if word_id is not None:
+            indices.append(word_id)
+
+        # Extract character n-grams
+        word_tagged = f"<{word}>"
+        L = len(word_tagged)
+
+        for n in range(self.min_n, self.max_n + 1):
+            for i in range(L - n + 1):
+                ngram = word_tagged[i : i + n]
+                if ngram != word and ngram != word_tagged:
+                    bucket_idx = fast_hash(ngram) % self.num_tokens
+                    indices.append(3 + self.nwords + bucket_idx)
+
+        return indices if indices else [self.unk_token_id]
+
+    def get(self, word: str) -> List[int]:
+        """Get subwords with on-demand computation for OOV words."""
+        if word not in self.cache:
+            word_id = self.word_to_id.get(word)
+            self.cache[word] = self._compute_subwords(word, word_id)
+        return self.cache[word]
+
+
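A small sketch of the FastText-style id layout `SubwordCache` produces (ids 0-2 are PAD/UNK/EOS, then vocabulary words, then hashed n-gram buckets; the toy vocabulary and bucket count are made up):

```python
from torchTextClassifiers.tokenizers.ngram import SubwordCache

cache = SubwordCache(
    word_to_id={"hello": 3, "world": 4},  # vocabulary ids start at 3
    min_n=3,
    max_n=4,
    num_tokens=1000,   # number of hash buckets for character n-grams
    nwords=2,
    unk_token_id=1,
)

print(cache.get("hello")[0])                  # 3 -> the word id itself comes first
print(min(cache.get("hello")[1:]) >= 3 + 2)   # True: n-gram buckets start after the vocab ids
print(cache.get("zzz"))                       # OOV word -> hashed character n-gram ids only
```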
+# ============================================================================
+# Vectorized encoding with optional metadata
+# ============================================================================
+
+
+def encode_batch_vectorized(
+    sentences: List[str],
+    subword_cache: SubwordCache,
+    eos_token_id: int,
+    pad_token_id: int,
+    max_length: Optional[int] = None,
+    truncation: bool = False,
+    return_offsets_mapping: bool = False,
+    return_word_ids: bool = False,
+    force_max_length: bool = False,
+) -> Tuple[torch.Tensor, torch.Tensor, Optional[List], Optional[List]]:
+    """
+    Vectorized batch encoding - processes all sentences together.
+    Returns padded tensors directly, with optional offset mappings and word IDs.
+
+    Args:
+        force_max_length: If True and max_length is set, always return tensors of size max_length
+    """
+    all_ids = []
+    all_offsets = [] if return_offsets_mapping else None
+    all_word_ids = [] if return_word_ids else None
+    max_len = 0
+
+    # First pass: encode all sentences
+    for sentence in sentences:
+        ids = []
+        offsets = [] if return_offsets_mapping else None
+        word_ids = [] if return_word_ids else None
+
+        words = sentence.split()
+        char_offset = 0
+
+        for word_idx, word in enumerate(words):
+            # Find the actual position of this word in the original sentence
+            word_start = sentence.find(word, char_offset)
+            word_end = word_start + len(word)
+            char_offset = word_end
+
+            # Get subword tokens for this word
+            subword_tokens = subword_cache.get(word)
+
+            for token_id in subword_tokens:
+                ids.append(token_id)
+
+                if return_offsets_mapping:
+                    # All subword tokens of a word map to the word's character span
+                    offsets.append((word_start, word_end))
+
+                if return_word_ids:
+                    # All subword tokens of a word get the same word_id
+                    word_ids.append(word_idx)
+
+        # Add EOS token
+        ids.append(eos_token_id)
+        if return_offsets_mapping:
+            offsets.append((len(sentence), len(sentence)))  # EOS has no span
+        if return_word_ids:
+            word_ids.append(None)  # EOS is not part of any word
+
+        # Truncate if needed
+        if truncation and max_length and len(ids) > max_length:
+            ids = ids[:max_length]
+            if return_offsets_mapping:
+                offsets = offsets[:max_length]
+            if return_word_ids:
+                word_ids = word_ids[:max_length]
+
+        all_ids.append(ids)
+        if return_offsets_mapping:
+            all_offsets.append(offsets)
+        if return_word_ids:
+            all_word_ids.append(word_ids)
+        max_len = max(max_len, len(ids))
+
+    # Determine final sequence length
+    if force_max_length and max_length:
+        # Always use max_length when force_max_length is True
+        seq_len = max_length
+    elif max_length and not truncation:
+        seq_len = min(max_len, max_length)
+    elif max_length:
+        seq_len = max_length
+    else:
+        seq_len = max_len
+
+    # Pre-allocate tensors
+    batch_size = len(sentences)
+    input_ids = torch.full((batch_size, seq_len), pad_token_id, dtype=torch.long)
+    attention_mask = torch.zeros((batch_size, seq_len), dtype=torch.long)
+
+    # Fill tensors and pad metadata
+    for i, ids in enumerate(all_ids):
+        length = min(len(ids), seq_len)
+        input_ids[i, :length] = torch.tensor(ids[:length], dtype=torch.long)
+        attention_mask[i, :length] = 1
+
+        # Pad offsets and word_ids to match sequence length
+        if return_offsets_mapping:
+            # Pad with (0, 0) for padding tokens
+            all_offsets[i] = all_offsets[i][:length] + [(0, 0)] * (seq_len - length)
+
+        if return_word_ids:
+            # Pad with None for padding tokens
+            all_word_ids[i] = all_word_ids[i][:length] + [None] * (seq_len - length)
+
+    return input_ids, attention_mask, all_offsets, all_word_ids
+
+
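A self-contained sketch of `encode_batch_vectorized` with a toy cache; the special-token ids mirror the `NGramTokenizer` defaults defined below, and the vocabulary is made up:

```python
from torchTextClassifiers.tokenizers.ngram import SubwordCache, encode_batch_vectorized

cache = SubwordCache(
    word_to_id={"hello": 3, "world": 4}, min_n=3, max_n=4,
    num_tokens=1000, nwords=2, unk_token_id=1,
)

input_ids, attention_mask, offsets, word_ids = encode_batch_vectorized(
    ["hello world", "hello"],
    subword_cache=cache,
    eos_token_id=2,   # matches NGramTokenizer.eos_token_id
    pad_token_id=0,   # matches NGramTokenizer.pad_token_id
    return_word_ids=True,
)

print(input_ids.shape == attention_mask.shape)  # True; padded to the longest sequence
print(input_ids[0, 0].item(), word_ids[0][0])   # 3 0 -> first sub-token of "hello", word index 0
print(word_ids[1][-1])                          # None -> padding position of the shorter sequence
```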
+# ============================================================================
+# NGramTokenizer - Optimized
+# ============================================================================
+
+
+class NGramTokenizer(BaseTokenizer):
+    """
+    Heavily optimized FastText N-gram tokenizer with:
+    - Pre-computed subword cache for the entire vocabulary
+    - Vectorized batch encoding
+    - Cached text normalization
+    - Direct tensor operations
+    - Optional offset mapping and word ID tracking
+    """
+
+    PAD_TOKEN = "[PAD]"
+    UNK_TOKEN = "[UNK]"
+    EOS_TOKEN = "</s>"
+
+    def __init__(
+        self,
+        min_count: int,
+        min_n: int,
+        max_n: int,
+        num_tokens: int,
+        len_word_ngrams: int,
+        training_text: Optional[List[str]] = None,
+        preprocess: bool = True,
+        output_dim: Optional[int] = None,
+        **kwargs,
+    ):
+        if min_n < 2:
+            raise ValueError("min_n must be >= 2")
+        if max_n > 6:
+            raise ValueError("max_n must be <= 6")
+
+        self.min_count = min_count
+        self.min_n = min_n
+        self.max_n = max_n
+        self.num_tokens = num_tokens
+        self.word_ngrams = len_word_ngrams
+        self.preprocess = preprocess
+
+        self.pad_token_id = 0
+        self.unk_token_id = 1
+        self.eos_token_id = 2
+
+        if training_text is not None:
+            self.train(training_text)
+        else:
+            self.word_to_id = {}
+            self.id_to_word = {}
+            self.nwords = 0
+            self.subword_cache = None
+
+        self.vocab_size = 3 + self.nwords + self.num_tokens
+
+        super().__init__(
+            vocab_size=self.vocab_size, padding_idx=self.pad_token_id, output_dim=output_dim
+        )
+
+    def train(self, training_text: List[str]):
+        """Build vocabulary from training text."""
+        word_counts = {}
+        for sent in training_text:
+            for w in sent.split():
+                word_counts[w] = word_counts.get(w, 0) + 1
+
+        self.word_to_id = {}
+        idx = 3
+        for w, c in word_counts.items():
+            if c >= self.min_count:
+                self.word_to_id[w] = idx
+                idx += 1
+
+        self.nwords = len(self.word_to_id)
+        self.vocab_size = 3 + self.nwords + self.num_tokens
+
+        # Create reverse mapping
+        self.id_to_word = {v: k for k, v in self.word_to_id.items()}
+        self.id_to_word[self.pad_token_id] = self.PAD_TOKEN
+        self.id_to_word[self.unk_token_id] = self.UNK_TOKEN
+        self.id_to_word[self.eos_token_id] = self.EOS_TOKEN
+
+        # Pre-compute all subwords for the vocabulary
+        print(f"Pre-computing subwords for {self.nwords} vocabulary words...")
+        self.subword_cache = SubwordCache(
+            self.word_to_id, self.min_n, self.max_n, self.num_tokens, self.nwords, self.unk_token_id
+        )
+        print("✓ Subword cache built")
+
+    def tokenize(
+        self,
+        text: Union[str, List[str]],
+        return_offsets_mapping: bool = False,
+        return_word_ids: bool = False,
+        **kwargs,
+    ) -> TokenizerOutput:
+        """
+        Optimized tokenization with vectorized operations.
+
+        Args:
+            text: Single string or list of strings to tokenize
+            return_offsets_mapping: If True, return character offsets for each token
+            return_word_ids: If True, return word indices for each token
+
+        Truncation and max length are driven by output_dim: if it is set, sequences are
+        truncated/padded to output_dim; otherwise they are padded to the longest sequence
+        in the batch.
+
+        Returns:
+            TokenizerOutput with input_ids, attention_mask, and optionally
+            offset_mapping and word_ids
+        """
+        is_single = isinstance(text, str)
+        if is_single:
+            text = [text]
+
+        # Fast cached text cleaning
+        if self.preprocess:
+            text = clean_text_feature(text)
+
+        if self.output_dim is not None:
+            max_length = self.output_dim
+            truncation = True
+        else:
+            max_length = None
+            truncation = False
+
+        # Vectorized encoding
+        input_ids, attention_mask, offsets, word_ids = encode_batch_vectorized(
+            text,
+            self.subword_cache,
+            self.eos_token_id,
+            self.pad_token_id,
+            max_length=max_length,
+            truncation=truncation,
+            return_offsets_mapping=return_offsets_mapping,
+            return_word_ids=return_word_ids,
+        )
+
+        offsets = torch.tensor(offsets) if return_offsets_mapping else None
+        word_ids = np.array(word_ids) if return_word_ids else None
+
+        return TokenizerOutput(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            word_ids=word_ids,
+            offset_mapping=offsets,
+        )
+
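An end-to-end sketch of training and tokenizing with `NGramTokenizer` (corpus and hyperparameters are illustrative):

```python
from torchTextClassifiers.tokenizers.ngram import NGramTokenizer

corpus = ["the cat sat on the mat", "the dog sat on the log"]
tok = NGramTokenizer(
    min_count=1,
    min_n=3,
    max_n=4,
    num_tokens=10_000,      # hash buckets for character n-grams
    len_word_ngrams=2,
    training_text=corpus,   # builds the vocabulary and the subword cache
    output_dim=32,          # pad/truncate every sequence to 32 tokens
)

out = tok("The cat sat!", return_offsets_mapping=True, return_word_ids=True)
print(out.input_ids.shape)       # torch.Size([1, 32])
print(out.offset_mapping[0, 0])  # tensor([0, 3]): span of the first sub-token in the cleaned text
print(out.word_ids[0][:5])       # [0 0 0 0 0]: the first sub-tokens all belong to word 0 ("the")
```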
+    def decode(
+        self, token_ids: Union[List[int], torch.Tensor], skip_special_tokens: bool = True
+    ) -> str:
+        """Decode token IDs back to text."""
+        if isinstance(token_ids, torch.Tensor):
+            token_ids = token_ids.tolist()
+
+        tokens = []
+        for id_ in token_ids:
+            if id_ == self.pad_token_id and skip_special_tokens:
+                continue
+
+            if id_ == self.eos_token_id:
+                if not skip_special_tokens:
+                    tokens.append(self.EOS_TOKEN)
+                continue
+
+            if id_ in self.id_to_word:
+                tokens.append(self.id_to_word[id_])
+            elif not skip_special_tokens:
+                tokens.append(f"[ID:{id_}]")
+
+        return " ".join(tokens)
+
+    def batch_decode(
+        self, sequences: Union[List[List[int]], torch.Tensor], skip_special_tokens: bool = True
+    ) -> List[str]:
+        """Decode multiple sequences."""
+        if isinstance(sequences, torch.Tensor):
+            sequences = sequences.tolist()
+        return [self.decode(seq, skip_special_tokens) for seq in sequences]
+
+    def save_pretrained(self, save_directory: str):
+        """Save tokenizer configuration and vocabulary."""
+        import os
+
+        os.makedirs(save_directory, exist_ok=True)
+
+        config = {
+            "min_count": self.min_count,
+            "min_n": self.min_n,
+            "max_n": self.max_n,
+            "num_tokens": self.num_tokens,
+            "len_word_ngrams": self.word_ngrams,
+            "word_to_id": self.word_to_id,
+            "preprocess": self.preprocess,
+            "vocab_size": self.vocab_size,
+            "nwords": self.nwords,
+        }
+
+        with open(f"{save_directory}/tokenizer.json", "w") as f:
+            json.dump(config, f, indent=2)
+
+        print(f"✓ Tokenizer saved to {save_directory}")
+
+    @classmethod
+    def from_pretrained(cls, directory: str):
+        """Load tokenizer from saved configuration."""
+        with open(f"{directory}/tokenizer.json", "r") as f:
+            config = json.load(f)
+
+        tokenizer = cls(
+            min_count=config["min_count"],
+            min_n=config["min_n"],
+            max_n=config["max_n"],
+            num_tokens=config["num_tokens"],
+            len_word_ngrams=config["len_word_ngrams"],
+            preprocess=config["preprocess"],
+            training_text=None,
+        )
+
+        tokenizer.word_to_id = config["word_to_id"]
+        tokenizer.nwords = config["nwords"]
+        tokenizer.vocab_size = config["vocab_size"]
+
+        tokenizer.id_to_word = {v: k for k, v in tokenizer.word_to_id.items()}
+        tokenizer.id_to_word[tokenizer.pad_token_id] = cls.PAD_TOKEN
+        tokenizer.id_to_word[tokenizer.unk_token_id] = cls.UNK_TOKEN
+        tokenizer.id_to_word[tokenizer.eos_token_id] = cls.EOS_TOKEN
+
+        # Rebuild subword cache
+        print("Rebuilding subword cache...")
+        tokenizer.subword_cache = SubwordCache(
+            tokenizer.word_to_id,
+            tokenizer.min_n,
+            tokenizer.max_n,
+            tokenizer.num_tokens,
+            tokenizer.nwords,
+            tokenizer.unk_token_id,
+        )
+        print("✓ Subword cache built")
+
+        print(f"✓ Tokenizer loaded from {directory}")
+        return tokenizer
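Finally, a roundtrip sketch covering `decode`, `save_pretrained`, and `from_pretrained` (the save path and corpus are illustrative):

```python
from torchTextClassifiers.tokenizers.ngram import NGramTokenizer

tok = NGramTokenizer(
    min_count=1, min_n=3, max_n=4, num_tokens=10_000, len_word_ngrams=2,
    training_text=["the cat sat on the mat"],
)

ids = tok("the cat sat").input_ids
# Only vocabulary words survive decoding: hashed n-gram and EOS ids are dropped
# when skip_special_tokens=True.
print(tok.batch_decode(ids))  # ['the cat sat']

tok.save_pretrained("/tmp/ngram_tok")          # writes /tmp/ngram_tok/tokenizer.json
reloaded = NGramTokenizer.from_pretrained("/tmp/ngram_tok")
print(reloaded.decode(ids[0]) == tok.decode(ids[0]))  # True: vocab and cache are rebuilt
```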