torchtextclassifiers 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff compares two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only and reflects the changes between those versions.
Files changed (34)
  1. torchTextClassifiers/__init__.py +12 -48
  2. torchTextClassifiers/dataset/__init__.py +1 -0
  3. torchTextClassifiers/dataset/dataset.py +114 -0
  4. torchTextClassifiers/model/__init__.py +2 -0
  5. torchTextClassifiers/model/components/__init__.py +12 -0
  6. torchTextClassifiers/model/components/attention.py +126 -0
  7. torchTextClassifiers/model/components/categorical_var_net.py +128 -0
  8. torchTextClassifiers/model/components/classification_head.py +43 -0
  9. torchTextClassifiers/model/components/text_embedder.py +220 -0
  10. torchTextClassifiers/model/lightning.py +166 -0
  11. torchTextClassifiers/model/model.py +151 -0
  12. torchTextClassifiers/tokenizers/WordPiece.py +92 -0
  13. torchTextClassifiers/tokenizers/__init__.py +10 -0
  14. torchTextClassifiers/tokenizers/base.py +205 -0
  15. torchTextClassifiers/tokenizers/ngram.py +472 -0
  16. torchTextClassifiers/torchTextClassifiers.py +463 -405
  17. torchTextClassifiers/utilities/__init__.py +0 -3
  18. torchTextClassifiers/utilities/plot_explainability.py +184 -0
  19. torchtextclassifiers-0.1.0.dist-info/METADATA +73 -0
  20. torchtextclassifiers-0.1.0.dist-info/RECORD +21 -0
  21. {torchtextclassifiers-0.0.1.dist-info → torchtextclassifiers-0.1.0.dist-info}/WHEEL +1 -1
  22. torchTextClassifiers/classifiers/base.py +0 -83
  23. torchTextClassifiers/classifiers/fasttext/__init__.py +0 -25
  24. torchTextClassifiers/classifiers/fasttext/core.py +0 -269
  25. torchTextClassifiers/classifiers/fasttext/model.py +0 -752
  26. torchTextClassifiers/classifiers/fasttext/tokenizer.py +0 -346
  27. torchTextClassifiers/classifiers/fasttext/wrapper.py +0 -216
  28. torchTextClassifiers/classifiers/simple_text_classifier.py +0 -191
  29. torchTextClassifiers/factories.py +0 -34
  30. torchTextClassifiers/utilities/checkers.py +0 -108
  31. torchTextClassifiers/utilities/preprocess.py +0 -82
  32. torchTextClassifiers/utilities/utils.py +0 -346
  33. torchtextclassifiers-0.0.1.dist-info/METADATA +0 -187
  34. torchtextclassifiers-0.0.1.dist-info/RECORD +0 -17
torchTextClassifiers/tokenizers/base.py
@@ -0,0 +1,205 @@
1
+ from abc import ABC, abstractmethod
2
+ from dataclasses import asdict, dataclass
3
+ from typing import Any, Dict, List, Optional, Union
4
+
5
+ import numpy as np
6
+ import torch
7
+
8
+ try:
9
+ from tokenizers import Tokenizer
10
+ from transformers import AutoTokenizer, PreTrainedTokenizerFast
11
+
12
+ HAS_HF = True
13
+ except ImportError:
14
+ HAS_HF = False
15
+
16
+
17
+ @dataclass
18
+ class TokenizerOutput:
19
+ input_ids: torch.Tensor # shape: (batch_size, seq_len)
20
+ attention_mask: torch.Tensor # shape: (batch_size, seq_len)
21
+ offset_mapping: Optional[torch.Tensor] = None # shape: (batch_size, seq_len, 2)
22
+ word_ids: Optional[np.ndarray] = None # shape: (batch_size, seq_len)
23
+
24
+ def to_dict(self) -> Dict[str, Any]:
25
+ return asdict(self)
26
+
27
+ @classmethod
28
+ def from_dict(cls, data: Dict[str, Any]) -> "TokenizerOutput":
29
+ return cls(**data)
30
+
31
+ def __post_init__(self):
32
+ # --- Basic type checks ---
33
+ if not isinstance(self.input_ids, torch.Tensor):
34
+ raise TypeError(f"input_ids must be a torch.Tensor, got {type(self.input_ids)}")
35
+ if not isinstance(self.attention_mask, torch.Tensor):
36
+ raise TypeError(
37
+ f"attention_mask must be a torch.Tensor, got {type(self.attention_mask)}"
38
+ )
39
+ if self.offset_mapping is not None and not isinstance(self.offset_mapping, torch.Tensor):
40
+ raise TypeError(
41
+ f"offset_mapping must be a torch.Tensor or None, got {type(self.offset_mapping)}"
42
+ )
43
+ if self.word_ids is not None and not isinstance(self.word_ids, np.ndarray):
44
+ raise TypeError(f"word_ids must be a numpy.ndarray or None, got {type(self.word_ids)}")
45
+
46
+ # --- Shape consistency checks ---
47
+ if self.input_ids.shape != self.attention_mask.shape:
48
+ raise ValueError(
49
+ f"Shape mismatch: input_ids {self.input_ids.shape} and attention_mask {self.attention_mask.shape}"
50
+ )
51
+
52
+ if self.offset_mapping is not None:
53
+ expected_shape = (*self.input_ids.shape, 2)
54
+ if self.offset_mapping.shape != expected_shape:
55
+ raise ValueError(
56
+ f"offset_mapping should have shape {expected_shape}, got {self.offset_mapping.shape}"
57
+ )
58
+
59
+ if self.word_ids is not None:
60
+ if self.word_ids.shape != self.input_ids.shape:
61
+ raise ValueError(
62
+ f"word_ids should have shape {self.input_ids.shape}, got {self.word_ids.shape}"
63
+ )
64
+
65
+
66
+ class BaseTokenizer(ABC):
67
+ def __init__(
68
+ self,
69
+ vocab_size: int,
70
+ padding_idx: int,
71
+ output_vectorized: bool = False,
72
+ output_dim: Optional[int] = None,
73
+ ):
74
+ """
75
+ Base class for tokenizers.
76
+ Args:
77
+ vocab_size (int): Size of the vocabulary.
+ padding_idx (int): Index of the padding token.
+ output_vectorized (bool): Whether the tokenizer outputs vectors rather than
+ token ids (e.g. a TF-IDF tokenizer).
+ output_dim (Optional[int]): Fixed output dimension (vector size or fixed
+ context length); required when output_vectorized is True.
80
+ """
81
+
82
+ self.vocab_size = vocab_size
83
+ self.output_vectorized = output_vectorized
84
+ self.output_dim = output_dim
85
+ self.padding_idx = padding_idx
86
+ if self.output_vectorized:
87
+ if output_dim is None:
88
+ raise ValueError(
89
+ "Tokenizer's output_dim must be provided if output_vectorized is True."
90
+ )
91
+
92
+ @abstractmethod
93
+ def tokenize(self, text: Union[str, List[str]]) -> TokenizerOutput:
94
+ """Tokenizes the raw input text and returns a TokenizerOutput."""
95
+ pass
96
+
97
+ def __len__(self):
98
+ return self.vocab_size
99
+
100
+ def __repr__(self):
101
+ return f"{self.__class__.__name__}(vocab_size={self.vocab_size}, output_vectorized={self.output_vectorized}, output_dim={self.output_dim})"
102
+
103
+ def __call__(self, text: Union[str, List[str]], **kwargs) -> TokenizerOutput:
104
+ return self.tokenize(text, **kwargs)
105
+
106
+
107
+ class HuggingFaceTokenizer(BaseTokenizer):
108
+ def __init__(
109
+ self,
110
+ vocab_size: int,
111
+ output_dim: Optional[int] = None,
112
+ padding_idx: Optional[int] = None,
113
+ trained: bool = False,
114
+ ):
115
+ super().__init__(
116
+ vocab_size, output_vectorized=False, output_dim=output_dim, padding_idx=padding_idx
117
+ ) # it outputs token ids and not vectors
118
+
119
+ self.trained = trained
120
+ self.tokenizer = None
121
+ self.padding_idx = padding_idx
122
+ self.output_dim = output_dim # constant context size for the whole batch
123
+
124
+ def tokenize(
125
+ self,
126
+ text: Union[str, List[str]],
127
+ return_offsets_mapping: Optional[bool] = False,
128
+ return_word_ids: Optional[bool] = False,
129
+ ) -> TokenizerOutput:
130
+ if not self.trained:
131
+ raise RuntimeError("Tokenizer must be trained before tokenization.")
132
+
133
+ # Pad to longest sequence if no output_dim is specified
134
+ padding = True if self.output_dim is None else "max_length"
135
+ truncation = self.output_dim is not None
136
+
137
+ tokenize_output = self.tokenizer(
138
+ text,
139
+ padding=padding,
140
+ return_tensors="pt",
141
+ truncation=truncation,
142
+ max_length=self.output_dim,
143
+ return_offsets_mapping=return_offsets_mapping,
144
+ ) # method from PreTrainedTokenizerFast
145
+
146
+ encoded_text = tokenize_output["input_ids"]
147
+
148
+ if return_word_ids:
149
+ word_ids = np.array([tokenize_output.word_ids(i) for i in range(len(encoded_text))])
150
+ else:
151
+ word_ids = None
152
+
153
+ return TokenizerOutput(
154
+ input_ids=encoded_text,
155
+ attention_mask=tokenize_output["attention_mask"],
156
+ offset_mapping=tokenize_output.get("offset_mapping", None),
157
+ word_ids=word_ids,
158
+ )
159
+
160
+ @classmethod
161
+ def load_from_pretrained(cls, tokenizer_name: str, output_dim: Optional[int] = None):
162
+ tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
163
+ padding_idx = tokenizer.pad_token_id
164
+ instance = cls(
165
+ vocab_size=len(tokenizer), trained=True, padding_idx=padding_idx, output_dim=output_dim
166
+ )
167
+ instance.tokenizer = tokenizer
168
+ return instance
169
+
170
+ @classmethod
171
+ def load(cls, load_path: str):
172
+ loaded_tokenizer = PreTrainedTokenizerFast(tokenizer_file=load_path)
173
+ instance = cls(vocab_size=len(loaded_tokenizer), trained=True, padding_idx=loaded_tokenizer.pad_token_id)
174
+ instance.tokenizer = loaded_tokenizer
175
+ # instance._post_training()
176
+ return instance
177
+
178
+ @classmethod
179
+ def load_from_s3(cls, s3_path: str, filesystem):
180
+ if not filesystem.exists(s3_path):
181
+ raise FileNotFoundError(
182
+ f"Tokenizer not found at {s3_path}. Please train it first (see src/train_tokenizers)."
183
+ )
184
+
185
+ with filesystem.open(s3_path, "rb") as f:
186
+ json_str = f.read().decode("utf-8")
187
+
188
+ tokenizer_obj = Tokenizer.from_str(json_str)
189
+ tokenizer = PreTrainedTokenizerFast(tokenizer_object=tokenizer_obj)
190
+ instance = cls(vocab_size=len(tokenizer), trained=True, padding_idx=tokenizer.pad_token_id)
+ instance.tokenizer = tokenizer
+ # _post_training() is not implemented on HuggingFaceTokenizer, so it is not called here (see load())
193
+ return instance
194
+
195
+ def train(self, *args, **kwargs):
196
+ raise NotImplementedError(
197
+ "This tokenizer cannot be trained directly. "
198
+ "Load it from pretrained or implement train() in a subclass."
199
+ )
200
+
201
+ def _post_training(self):
202
+ raise NotImplementedError("_post_training() not implemented for HuggingFaceTokenizer.")
203
+
204
+ def __repr__(self):
205
+ return f"{self.__class__.__name__} \n HuggingFace tokenizer: {self.tokenizer.__repr__()}"
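For orientation, here is a minimal usage sketch of the HuggingFaceTokenizer introduced above. It is illustrative only, not part of the package diff: it assumes the optional HuggingFace dependencies (tokenizers, transformers) are installed, uses "bert-base-uncased" purely as an example checkpoint, and assumes the class is importable from torchTextClassifiers.tokenizers.base as laid out in the file list.

from torchTextClassifiers.tokenizers.base import HuggingFaceTokenizer

# padding_idx is taken from the checkpoint's pad_token_id; output_dim fixes the
# context size, so every batch is padded/truncated to 32 tokens.
tokenizer = HuggingFaceTokenizer.load_from_pretrained("bert-base-uncased", output_dim=32)

batch = ["hello world", "a slightly longer second sentence"]
out = tokenizer(batch, return_word_ids=True)  # __call__ forwards to tokenize()

print(out.input_ids.shape)       # torch.Size([2, 32])
print(out.attention_mask.shape)  # torch.Size([2, 32])
print(out.word_ids.shape)        # (2, 32); None entries mark special and padding tokens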
torchTextClassifiers/tokenizers/ngram.py
@@ -0,0 +1,472 @@
1
+ import json
2
+ import re
3
+ import unicodedata
4
+ from functools import lru_cache
5
+ from typing import List, Optional, Tuple, Union
6
+
7
+ import numpy as np
8
+ import torch
9
+
10
+ from torchTextClassifiers.tokenizers import BaseTokenizer, TokenizerOutput
11
+
12
+ # ============================================================================
13
+ # Optimized normalization
14
+ # ============================================================================
15
+
16
+ _fasttext_non_alnum = re.compile(r"[^a-z0-9]+")
17
+ _fasttext_multi_space = re.compile(r"\s+")
18
+
19
+ # Pre-compile translation table for faster character removal
20
+ _COMBINING_MARKS = {c: None for c in range(0x0300, 0x0370)}
21
+
22
+
23
+ @lru_cache(maxsize=10000)
24
+ def _clean_single_text_cached(text: str) -> str:
25
+ """Cached version of text cleaning - major speedup for repeated texts."""
26
+ t = text.lower()
27
+ t = unicodedata.normalize("NFKD", t)
28
+ # Faster: use translate() instead of list comprehension
29
+ t = t.translate(_COMBINING_MARKS)
30
+ t = _fasttext_non_alnum.sub(" ", t)
31
+ t = _fasttext_multi_space.sub(" ", t)
32
+ return t.strip()
33
+
34
+
35
+ def clean_text_feature(texts: List[str]) -> List[str]:
36
+ """Vectorized text cleaning with caching."""
37
+ return [_clean_single_text_cached(t) for t in texts]
38
+
39
+
40
+ # ============================================================================
41
+ # Optimized hash function
42
+ # ============================================================================
43
+
44
+
45
+ def fast_hash(s: str) -> int:
46
+ """FNV-1a hash - simple and fast."""
47
+ h = 2166136261
48
+ for c in s:
49
+ h ^= ord(c)
50
+ h = (h * 16777619) & 0xFFFFFFFF
51
+ return h
52
+
53
+
54
+ # ============================================================================
55
+ # Pre-computed subword cache
56
+ # ============================================================================
57
+
58
+
59
+ class SubwordCache:
60
+ """Aggressive pre-computation cache for subwords."""
61
+
62
+ def __init__(
63
+ self,
64
+ word_to_id: dict,
65
+ min_n: int,
66
+ max_n: int,
67
+ num_tokens: int,
68
+ nwords: int,
69
+ unk_token_id: int,
70
+ ):
71
+ self.cache = {}
72
+ self.word_to_id = word_to_id
73
+ self.min_n = min_n
74
+ self.max_n = max_n
75
+ self.num_tokens = num_tokens
76
+ self.nwords = nwords
77
+ self.unk_token_id = unk_token_id
78
+
79
+ # Pre-compute for all vocabulary words
80
+ self._precompute_vocab()
81
+
82
+ def _precompute_vocab(self):
83
+ """Pre-compute subwords for entire vocabulary."""
84
+ for word, word_id in self.word_to_id.items():
85
+ self.cache[word] = self._compute_subwords(word, word_id)
86
+
87
+ def _compute_subwords(self, word: str, word_id: Optional[int] = None) -> List[int]:
88
+ """Compute subword indices for a word."""
89
+ indices = []
90
+
91
+ # Add word token if in vocab
92
+ if word_id is not None:
93
+ indices.append(word_id)
94
+
95
+ # Extract character n-grams
96
+ word_tagged = f"<{word}>"
97
+ L = len(word_tagged)
98
+
99
+ for n in range(self.min_n, self.max_n + 1):
100
+ for i in range(L - n + 1):
101
+ ngram = word_tagged[i : i + n]
102
+ if ngram != word and ngram != word_tagged:
103
+ bucket_idx = fast_hash(ngram) % self.num_tokens
104
+ indices.append(3 + self.nwords + bucket_idx)
105
+
106
+ return indices if indices else [self.unk_token_id]
107
+
108
+ def get(self, word: str) -> List[int]:
109
+ """Get subwords with on-demand computation for OOV words."""
110
+ if word not in self.cache:
111
+ word_id = self.word_to_id.get(word)
112
+ self.cache[word] = self._compute_subwords(word, word_id)
113
+ return self.cache[word]
114
+
115
+
116
+ # ============================================================================
117
+ # Vectorized encoding with optional metadata
118
+ # ============================================================================
119
+
120
+
121
+ def encode_batch_vectorized(
122
+ sentences: List[str],
123
+ subword_cache: SubwordCache,
124
+ eos_token_id: int,
125
+ pad_token_id: int,
126
+ max_length: Optional[int] = None,
127
+ truncation: bool = False,
128
+ return_offsets_mapping: bool = False,
129
+ return_word_ids: bool = False,
130
+ force_max_length: bool = False,
131
+ ) -> Tuple[torch.Tensor, torch.Tensor, Optional[List], Optional[List]]:
132
+ """
133
+ Vectorized batch encoding - processes all sentences together.
134
+ Returns padded tensors directly, with optional offset mappings and word IDs.
135
+
136
+ Args:
137
+ force_max_length: If True and max_length is set, always return tensors of size max_length
138
+ """
139
+ all_ids = []
140
+ all_offsets = [] if return_offsets_mapping else None
141
+ all_word_ids = [] if return_word_ids else None
142
+ max_len = 0
143
+
144
+ # First pass: encode all sentences
145
+ for sentence in sentences:
146
+ ids = []
147
+ offsets = [] if return_offsets_mapping else None
148
+ word_ids = [] if return_word_ids else None
149
+
150
+ words = sentence.split()
151
+ char_offset = 0
152
+
153
+ for word_idx, word in enumerate(words):
154
+ # Find the actual position of this word in the original sentence
155
+ word_start = sentence.find(word, char_offset)
156
+ word_end = word_start + len(word)
157
+ char_offset = word_end
158
+
159
+ # Get subword tokens for this word
160
+ subword_tokens = subword_cache.get(word)
161
+
162
+ for token_id in subword_tokens:
163
+ ids.append(token_id)
164
+
165
+ if return_offsets_mapping:
166
+ # All subword tokens of a word map to the word's character span
167
+ offsets.append((word_start, word_end))
168
+
169
+ if return_word_ids:
170
+ # All subword tokens of a word get the same word_id
171
+ word_ids.append(word_idx)
172
+
173
+ # Add EOS token
174
+ ids.append(eos_token_id)
175
+ if return_offsets_mapping:
176
+ offsets.append((len(sentence), len(sentence))) # EOS has no span
177
+ if return_word_ids:
178
+ word_ids.append(None) # EOS is not part of any word
179
+
180
+ # Truncate if needed
181
+ if truncation and max_length and len(ids) > max_length:
182
+ ids = ids[:max_length]
183
+ if return_offsets_mapping:
184
+ offsets = offsets[:max_length]
185
+ if return_word_ids:
186
+ word_ids = word_ids[:max_length]
187
+
188
+ all_ids.append(ids)
189
+ if return_offsets_mapping:
190
+ all_offsets.append(offsets)
191
+ if return_word_ids:
192
+ all_word_ids.append(word_ids)
193
+ max_len = max(max_len, len(ids))
194
+
195
+ # Determine final sequence length
196
+ if force_max_length and max_length:
197
+ # Always use max_length when force_max_length is True
198
+ seq_len = max_length
199
+ elif max_length and not truncation:
200
+ seq_len = min(max_len, max_length)
201
+ elif max_length:
202
+ seq_len = max_length
203
+ else:
204
+ seq_len = max_len
205
+
206
+ # Pre-allocate tensors
207
+ batch_size = len(sentences)
208
+ input_ids = torch.full((batch_size, seq_len), pad_token_id, dtype=torch.long)
209
+ attention_mask = torch.zeros((batch_size, seq_len), dtype=torch.long)
210
+
211
+ # Fill tensors and pad metadata
212
+ for i, ids in enumerate(all_ids):
213
+ length = min(len(ids), seq_len)
214
+ input_ids[i, :length] = torch.tensor(ids[:length], dtype=torch.long)
215
+ attention_mask[i, :length] = 1
216
+
217
+ # Pad offsets and word_ids to match sequence length
218
+ if return_offsets_mapping:
219
+ # Pad with (0, 0) for padding tokens
220
+ all_offsets[i] = all_offsets[i][:length] + [(0, 0)] * (seq_len - length)
221
+
222
+ if return_word_ids:
223
+ # Pad with None for padding tokens
224
+ all_word_ids[i] = all_word_ids[i][:length] + [None] * (seq_len - length)
225
+
226
+ return input_ids, attention_mask, all_offsets, all_word_ids
227
+
228
+
229
+ # ============================================================================
230
+ # NGramTokenizer - Optimized
231
+ # ============================================================================
232
+
233
+
234
+ class NGramTokenizer(BaseTokenizer):
235
+ """
236
+ Heavily optimized FastText N-gram tokenizer with:
237
+ - Pre-computed subword cache for entire vocabulary
238
+ - Vectorized batch encoding
239
+ - Cached text normalization
240
+ - Direct tensor operations
241
+ - Optional offset mapping and word ID tracking
242
+ """
243
+
244
+ PAD_TOKEN = "[PAD]"
245
+ UNK_TOKEN = "[UNK]"
246
+ EOS_TOKEN = "</s>"
247
+
248
+ def __init__(
249
+ self,
250
+ min_count: int,
251
+ min_n: int,
252
+ max_n: int,
253
+ num_tokens: int,
254
+ len_word_ngrams: int,
255
+ training_text: Optional[List[str]] = None,
256
+ preprocess: bool = True,
257
+ output_dim: Optional[int] = None,
258
+ **kwargs,
259
+ ):
260
+ if min_n < 2:
261
+ raise ValueError("min_n must be >= 2")
262
+ if max_n > 6:
263
+ raise ValueError("max_n must be <= 6")
264
+
265
+ self.min_count = min_count
266
+ self.min_n = min_n
267
+ self.max_n = max_n
268
+ self.num_tokens = num_tokens
269
+ self.word_ngrams = len_word_ngrams
270
+ self.preprocess = preprocess
271
+
272
+ self.pad_token_id = 0
273
+ self.unk_token_id = 1
274
+ self.eos_token_id = 2
275
+
276
+ if training_text is not None:
277
+ self.train(training_text)
278
+ else:
279
+ self.word_to_id = {}
280
+ self.id_to_word = {}
281
+ self.nwords = 0
282
+ self.subword_cache = None
283
+
284
+ self.vocab_size = 3 + self.nwords + self.num_tokens
285
+
286
+ super().__init__(
287
+ vocab_size=self.vocab_size, padding_idx=self.pad_token_id, output_dim=output_dim
288
+ )
289
+
290
+ def train(self, training_text: List[str]):
291
+ """Build vocabulary from training text."""
292
+ word_counts = {}
293
+ for sent in training_text:
294
+ for w in sent.split():
295
+ word_counts[w] = word_counts.get(w, 0) + 1
296
+
297
+ self.word_to_id = {}
298
+ idx = 3
299
+ for w, c in word_counts.items():
300
+ if c >= self.min_count:
301
+ self.word_to_id[w] = idx
302
+ idx += 1
303
+
304
+ self.nwords = len(self.word_to_id)
305
+ self.vocab_size = 3 + self.nwords + self.num_tokens
306
+
307
+ # Create reverse mapping
308
+ self.id_to_word = {v: k for k, v in self.word_to_id.items()}
309
+ self.id_to_word[self.pad_token_id] = self.PAD_TOKEN
310
+ self.id_to_word[self.unk_token_id] = self.UNK_TOKEN
311
+ self.id_to_word[self.eos_token_id] = self.EOS_TOKEN
312
+
313
+ # Pre-compute all subwords for vocabulary
314
+ print(f"Pre-computing subwords for {self.nwords} vocabulary words...")
315
+ self.subword_cache = SubwordCache(
316
+ self.word_to_id, self.min_n, self.max_n, self.num_tokens, self.nwords, self.unk_token_id
317
+ )
318
+ print("✓ Subword cache built")
319
+
320
+ def tokenize(
321
+ self,
322
+ text: Union[str, List[str]],
323
+ return_offsets_mapping: bool = False,
324
+ return_word_ids: bool = False,
325
+ **kwargs,
326
+ ) -> TokenizerOutput:
327
+ """
328
+ Optimized tokenization with vectorized operations. Padding and truncation are
+ controlled by self.output_dim: when it is set, sequences are padded or truncated
+ to that length; otherwise they are padded to the longest sequence in the batch.
+
+ Args:
+ text: Single string or list of strings to tokenize
335
+ return_offsets_mapping: If True, return character offsets for each token
336
+ return_word_ids: If True, return word indices for each token
337
+
338
+ Returns:
339
+ TokenizerOutput with input_ids, attention_mask, and optionally
340
+ offset_mapping and word_ids
341
+ """
342
+ is_single = isinstance(text, str)
343
+ if is_single:
344
+ text = [text]
345
+
346
+ # Fast cached text cleaning
347
+ if self.preprocess:
348
+ text = clean_text_feature(text)
349
+
350
+ if self.output_dim is not None:
351
+ max_length = self.output_dim
352
+ truncation = True
353
+ else:
354
+ max_length = None
355
+ truncation = False
356
+
357
+ # Vectorized encoding
358
+ input_ids, attention_mask, offsets, word_ids = encode_batch_vectorized(
359
+ text,
360
+ self.subword_cache,
361
+ self.eos_token_id,
362
+ self.pad_token_id,
363
+ max_length=max_length,
364
+ truncation=truncation,
365
+ return_offsets_mapping=return_offsets_mapping,
366
+ return_word_ids=return_word_ids,
367
+ )
368
+
369
+ offsets = torch.tensor(offsets) if return_offsets_mapping else None
370
+ word_ids = np.array(word_ids) if return_word_ids else None
371
+
372
+ return TokenizerOutput(
373
+ input_ids=input_ids,
374
+ attention_mask=attention_mask,
375
+ word_ids=word_ids,
376
+ offset_mapping=offsets,
377
+ )
378
+
379
+ def decode(
380
+ self, token_ids: Union[List[int], torch.Tensor], skip_special_tokens: bool = True
381
+ ) -> str:
382
+ """Decode token IDs back to text."""
383
+ if isinstance(token_ids, torch.Tensor):
384
+ token_ids = token_ids.tolist()
385
+
386
+ tokens = []
387
+ for id_ in token_ids:
388
+ if id_ == self.pad_token_id and skip_special_tokens:
389
+ continue
390
+
391
+ if id_ == self.eos_token_id:
392
+ if not skip_special_tokens:
393
+ tokens.append(self.EOS_TOKEN)
394
+ continue
395
+
396
+ if id_ in self.id_to_word:
397
+ tokens.append(self.id_to_word[id_])
398
+ elif not skip_special_tokens:
399
+ tokens.append(f"[ID:{id_}]")
400
+
401
+ return " ".join(tokens)
402
+
403
+ def batch_decode(
404
+ self, sequences: Union[List[List[int]], torch.Tensor], skip_special_tokens: bool = True
405
+ ) -> List[str]:
406
+ """Decode multiple sequences."""
407
+ if isinstance(sequences, torch.Tensor):
408
+ sequences = sequences.tolist()
409
+ return [self.decode(seq, skip_special_tokens) for seq in sequences]
410
+
411
+ def save_pretrained(self, save_directory: str):
412
+ """Save tokenizer configuration and vocabulary."""
413
+ import os
414
+
415
+ os.makedirs(save_directory, exist_ok=True)
416
+
417
+ config = {
418
+ "min_count": self.min_count,
419
+ "min_n": self.min_n,
420
+ "max_n": self.max_n,
421
+ "num_tokens": self.num_tokens,
422
+ "len_word_ngrams": self.word_ngrams,
423
+ "word_to_id": self.word_to_id,
424
+ "preprocess": self.preprocess,
425
+ "vocab_size": self.vocab_size,
426
+ "nwords": self.nwords,
427
+ }
428
+
429
+ with open(f"{save_directory}/tokenizer.json", "w") as f:
430
+ json.dump(config, f, indent=2)
431
+
432
+ print(f"✓ Tokenizer saved to {save_directory}")
433
+
434
+ @classmethod
435
+ def from_pretrained(cls, directory: str):
436
+ """Load tokenizer from saved configuration."""
437
+ with open(f"{directory}/tokenizer.json", "r") as f:
438
+ config = json.load(f)
439
+
440
+ tokenizer = cls(
441
+ min_count=config["min_count"],
442
+ min_n=config["min_n"],
443
+ max_n=config["max_n"],
444
+ num_tokens=config["num_tokens"],
445
+ len_word_ngrams=config["len_word_ngrams"],
446
+ preprocess=config["preprocess"],
447
+ training_text=None,
448
+ )
449
+
450
+ tokenizer.word_to_id = config["word_to_id"]
451
+ tokenizer.nwords = config["nwords"]
452
+ tokenizer.vocab_size = config["vocab_size"]
453
+
454
+ tokenizer.id_to_word = {v: k for k, v in tokenizer.word_to_id.items()}
455
+ tokenizer.id_to_word[tokenizer.pad_token_id] = cls.PAD_TOKEN
456
+ tokenizer.id_to_word[tokenizer.unk_token_id] = cls.UNK_TOKEN
457
+ tokenizer.id_to_word[tokenizer.eos_token_id] = cls.EOS_TOKEN
458
+
459
+ # Rebuild subword cache
460
+ print("Rebuilding subword cache...")
461
+ tokenizer.subword_cache = SubwordCache(
462
+ tokenizer.word_to_id,
463
+ tokenizer.min_n,
464
+ tokenizer.max_n,
465
+ tokenizer.num_tokens,
466
+ tokenizer.nwords,
467
+ tokenizer.unk_token_id,
468
+ )
469
+ print("✓ Subword cache built")
470
+
471
+ print(f"✓ Tokenizer loaded from {directory}")
472
+ return tokenizer
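A corresponding sketch for the NGramTokenizer, again illustrative rather than part of the diff: the toy corpus and hyperparameters are arbitrary, and the import path is assumed from the file list above.

from torchTextClassifiers.tokenizers.ngram import NGramTokenizer

corpus = ["the cat sat on the mat", "the dog ate my homework"]

# Passing training_text runs train(), which builds the word vocabulary
# (filtered by min_count) and pre-computes the character n-gram subword cache.
tok = NGramTokenizer(
    min_count=1,
    min_n=3,
    max_n=6,
    num_tokens=1000,        # hash buckets for character n-grams
    len_word_ngrams=1,
    training_text=corpus,
    output_dim=64,          # fixed context size: pad/truncate every batch to 64
)

out = tok("the cat ate the mat")
print(out.input_ids.shape)           # torch.Size([1, 64])
print(tok.decode(out.input_ids[0]))  # recovers the in-vocabulary words; n-gram bucket ids are skipped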