torchtextclassifiers 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (34)
  1. torchTextClassifiers/__init__.py +12 -48
  2. torchTextClassifiers/dataset/__init__.py +1 -0
  3. torchTextClassifiers/dataset/dataset.py +114 -0
  4. torchTextClassifiers/model/__init__.py +2 -0
  5. torchTextClassifiers/model/components/__init__.py +12 -0
  6. torchTextClassifiers/model/components/attention.py +126 -0
  7. torchTextClassifiers/model/components/categorical_var_net.py +128 -0
  8. torchTextClassifiers/model/components/classification_head.py +43 -0
  9. torchTextClassifiers/model/components/text_embedder.py +220 -0
  10. torchTextClassifiers/model/lightning.py +166 -0
  11. torchTextClassifiers/model/model.py +151 -0
  12. torchTextClassifiers/tokenizers/WordPiece.py +92 -0
  13. torchTextClassifiers/tokenizers/__init__.py +10 -0
  14. torchTextClassifiers/tokenizers/base.py +205 -0
  15. torchTextClassifiers/tokenizers/ngram.py +472 -0
  16. torchTextClassifiers/torchTextClassifiers.py +463 -405
  17. torchTextClassifiers/utilities/__init__.py +0 -3
  18. torchTextClassifiers/utilities/plot_explainability.py +184 -0
  19. torchtextclassifiers-0.1.0.dist-info/METADATA +73 -0
  20. torchtextclassifiers-0.1.0.dist-info/RECORD +21 -0
  21. {torchtextclassifiers-0.0.1.dist-info → torchtextclassifiers-0.1.0.dist-info}/WHEEL +1 -1
  22. torchTextClassifiers/classifiers/base.py +0 -83
  23. torchTextClassifiers/classifiers/fasttext/__init__.py +0 -25
  24. torchTextClassifiers/classifiers/fasttext/core.py +0 -269
  25. torchTextClassifiers/classifiers/fasttext/model.py +0 -752
  26. torchTextClassifiers/classifiers/fasttext/tokenizer.py +0 -346
  27. torchTextClassifiers/classifiers/fasttext/wrapper.py +0 -216
  28. torchTextClassifiers/classifiers/simple_text_classifier.py +0 -191
  29. torchTextClassifiers/factories.py +0 -34
  30. torchTextClassifiers/utilities/checkers.py +0 -108
  31. torchTextClassifiers/utilities/preprocess.py +0 -82
  32. torchTextClassifiers/utilities/utils.py +0 -346
  33. torchtextclassifiers-0.0.1.dist-info/METADATA +0 -187
  34. torchtextclassifiers-0.0.1.dist-info/RECORD +0 -17
torchTextClassifiers/classifiers/fasttext/model.py (deleted)
@@ -1,752 +0,0 @@
- """FastText model components.
-
- This module contains the PyTorch model, Lightning module, and dataset classes
- for FastText classification. Consolidates what was previously in pytorch_model.py,
- lightning_module.py, and dataset.py.
- """
-
- import os
- import logging
- from typing import List, Union
- import torch
- import pytorch_lightning as pl
- from torch import nn
- from torchmetrics import Accuracy
-
- try:
-     from captum.attr import LayerIntegratedGradients
-     HAS_CAPTUM = True
- except ImportError:
-     HAS_CAPTUM = False
-
- from ...utilities.utils import (
-     compute_preprocessed_word_score,
-     compute_word_score,
-     explain_continuous,
- )
- from ...utilities.checkers import validate_categorical_inputs
-
- logger = logging.getLogger(__name__)
-
- logging.basicConfig(
-     level=logging.INFO,
-     format="%(asctime)s - %(name)s - %(message)s",
-     datefmt="%Y-%m-%d %H:%M:%S",
-     handlers=[logging.StreamHandler()],
- )
-
-
- # ============================================================================
- # PyTorch Model
- # ============================================================================
-
- class FastTextModel(nn.Module):
-     """FastText Pytorch Model."""
-
-     def __init__(
-         self,
-         embedding_dim: int,
-         num_classes: int,
-         tokenizer=None,
-         num_rows: int = None,
-         categorical_vocabulary_sizes: List[int] = None,
-         categorical_embedding_dims: Union[List[int], int] = None,
-         padding_idx: int = 0,
-         sparse: bool = True,
-         direct_bagging: bool = False,
-     ):
-         """
-         Constructor for the FastTextModel class.
-
-         Args:
-             embedding_dim (int): Dimension of the text embedding space.
-             buckets (int): Number of rows in the embedding matrix.
-             num_classes (int): Number of classes.
-             categorical_vocabulary_sizes (List[int]): List of the number of
-                 modalities for additional categorical features.
-             padding_idx (int, optional): Padding index for the text
-                 descriptions. Defaults to 0.
-             sparse (bool): Indicates if Embedding layer is sparse.
-             direct_bagging (bool): Use EmbeddingBag instead of Embedding for the text embedding.
-         """
-         super(FastTextModel, self).__init__()
-
-         if isinstance(categorical_embedding_dims, int):
-             self.average_cat_embed = True # if provided categorical embedding dims is an int, average the categorical embeddings before concatenating to sentence embedding
-         else:
-             self.average_cat_embed = False
-
-         categorical_vocabulary_sizes, categorical_embedding_dims, num_categorical_features = (
-             validate_categorical_inputs(
-                 categorical_vocabulary_sizes,
-                 categorical_embedding_dims,
-                 num_categorical_features=None,
-             )
-         )
-
-         assert isinstance(categorical_embedding_dims, list) or categorical_embedding_dims is None, (
-             "categorical_embedding_dims must be a list of int at this stage"
-         )
-
-         if categorical_embedding_dims is None:
-             self.average_cat_embed = False
-
-         if tokenizer is None:
-             if num_rows is None:
-                 raise ValueError(
-                     "Either tokenizer or num_rows must be provided (number of rows in the embedding matrix)."
-                 )
-         else:
-             if num_rows is not None:
-                 if num_rows != tokenizer.num_tokens:
-                     logger.warning(
-                         "num_rows is different from the number of tokens in the tokenizer. Using provided num_rows."
-                     )
-
-         self.num_rows = num_rows
-
-         self.num_classes = num_classes
-         self.padding_idx = padding_idx
-         self.tokenizer = tokenizer
-         self.embedding_dim = embedding_dim
-         self.direct_bagging = direct_bagging
-         self.sparse = sparse
-
-         self.categorical_embedding_dims = categorical_embedding_dims
-
-         self.embeddings = (
-             nn.Embedding(
-                 embedding_dim=embedding_dim,
-                 num_embeddings=num_rows,
-                 padding_idx=padding_idx,
-                 sparse=sparse,
-             )
-             if not direct_bagging
-             else nn.EmbeddingBag(
-                 embedding_dim=embedding_dim,
-                 num_embeddings=num_rows,
-                 padding_idx=padding_idx,
-                 sparse=sparse,
-                 mode="mean",
-             )
-         )
-
-         self.categorical_embedding_layers = {}
-
-         # Entry dim for the last layer:
-         # 1. embedding_dim if no categorical variables or summing the categrical embeddings to sentence embedding
-         # 2. embedding_dim + cat_embedding_dim if averaging the categorical embeddings before concatenating to sentence embedding (categorical_embedding_dims is a int)
-         # 3. embedding_dim + sum(categorical_embedding_dims) if concatenating individually the categorical embeddings to sentence embedding (no averaging, categorical_embedding_dims is a list)
-         dim_in_last_layer = embedding_dim
-         if self.average_cat_embed:
-             dim_in_last_layer += categorical_embedding_dims[0]
-
-         if categorical_vocabulary_sizes is not None:
-             self.no_cat_var = False
-             for var_idx, num_rows in enumerate(categorical_vocabulary_sizes):
-                 if categorical_embedding_dims is not None:
-                     emb = nn.Embedding(
-                         embedding_dim=categorical_embedding_dims[var_idx], num_embeddings=num_rows
-                     ) # concatenate to sentence embedding
-                     if not self.average_cat_embed:
-                         dim_in_last_layer += categorical_embedding_dims[var_idx]
-                 else:
-                     emb = nn.Embedding(
-                         embedding_dim=embedding_dim, num_embeddings=num_rows
-                     ) # sum to sentence embedding
-                 self.categorical_embedding_layers[var_idx] = emb
-                 setattr(self, "emb_{}".format(var_idx), emb)
-         else:
-             self.no_cat_var = True
-
-         self.fc = nn.Linear(dim_in_last_layer, num_classes)
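For reference, the input dimension of the final linear layer follows the three cases spelled out in the comment above dim_in_last_layer. A minimal worked example with illustrative sizes (none of these numbers come from the package):

    # embedding_dim = 80, two categorical variables with vocabularies [10, 20]

    # Case 1: categorical_embedding_dims is None -> categorical embeddings share
    # embedding_dim and are summed into the sentence embedding; the head keeps 80.
    dim_case_1 = 80

    # Case 2: categorical_embedding_dims is an int (say 16) -> the categorical
    # embeddings are averaged, then concatenated once: 80 + 16.
    dim_case_2 = 80 + 16   # 96

    # Case 3: categorical_embedding_dims is a list (say [16, 8]) -> each embedding
    # is concatenated individually: 80 + 16 + 8.
    dim_case_3 = 80 + 16 + 8   # 104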
-
-     def forward(self, encoded_text: torch.Tensor, additional_inputs: torch.Tensor) -> torch.Tensor:
-         """
-         Memory-efficient forward pass implementation.
-
-         Args:
-             encoded_text (torch.Tensor[Long]), shape (batch_size, seq_len): Tokenized + padded text
-             additional_inputs (torch.Tensor[Long]): Additional categorical features, (batch_size, num_categorical_features)
-
-         Returns:
-             torch.Tensor: Model output scores for each class
-         """
-         batch_size = encoded_text.size(0)
-
-         # Ensure correct dtype and device once
-         if encoded_text.dtype != torch.long:
-             encoded_text = encoded_text.to(torch.long)
-
-         # Compute text embeddings
-         if self.direct_bagging:
-             x_text = self.embeddings(encoded_text) # (batch_size, embedding_dim)
-         else:
-             # Compute embeddings and averaging in a memory-efficient way
-             x_text = self.embeddings(encoded_text) # (batch_size, seq_len, embedding_dim)
-             # Calculate non-zero tokens mask once
-             non_zero_mask = (x_text.sum(-1) != 0).float() # (batch_size, seq_len)
-             token_counts = non_zero_mask.sum(-1, keepdim=True) # (batch_size, 1)
-
-             # Sum and average in place
-             x_text = (x_text * non_zero_mask.unsqueeze(-1)).sum(
-                 dim=1
-             ) # (batch_size, embedding_dim)
-             x_text = torch.div(x_text, token_counts.clamp(min=1.0))
-             x_text = torch.nan_to_num(x_text, 0.0)
-
-         # Handle categorical variables efficiently
-         if not self.no_cat_var and additional_inputs.numel() > 0:
-             cat_embeds = []
-             # Process categorical embeddings in batch
-             for i, (_, embed_layer) in enumerate(self.categorical_embedding_layers.items()):
-                 cat_input = additional_inputs[:, i].long()
-
-                 # Check if categorical values are within valid range and clamp if needed
-                 vocab_size = embed_layer.num_embeddings
-                 max_val = cat_input.max().item()
-                 min_val = cat_input.min().item()
-
-                 if max_val >= vocab_size or min_val < 0:
-                     logger.warning(f"Categorical feature {i}: values range [{min_val}, {max_val}] exceed vocabulary size {vocab_size}. Clamping to valid range [0, {vocab_size - 1}]")
-                     # Clamp values to valid range
-                     cat_input = torch.clamp(cat_input, 0, vocab_size - 1)
-
-                 cat_embed = embed_layer(cat_input)
-                 if cat_embed.dim() > 2:
-                     cat_embed = cat_embed.squeeze(1)
-                 cat_embeds.append(cat_embed)
-
-             if cat_embeds: # If we have categorical embeddings
-                 if self.categorical_embedding_dims is not None:
-                     if self.average_cat_embed:
-                         # Stack and average in one operation
-                         x_cat = torch.stack(cat_embeds, dim=0).mean(dim=0)
-                         x_combined = torch.cat([x_text, x_cat], dim=1)
-                     else:
-                         # Optimize concatenation
-                         x_combined = torch.cat([x_text] + cat_embeds, dim=1)
-                 else:
-                     # Sum embeddings efficiently
-                     x_combined = x_text + torch.stack(cat_embeds, dim=0).sum(dim=0)
-             else:
-                 x_combined = x_text
-         else:
-             x_combined = x_text
-
-         # Final linear layer
-         return self.fc(x_combined)
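The non-bagging branch of this forward pass is a masked mean-pooling: padding positions embed to the zero vector and are excluded from the average. A self-contained sketch of that pooling step with toy shapes (not the package's API):

    import torch

    batch_size, seq_len, embedding_dim = 2, 4, 3
    x_text = torch.randn(batch_size, seq_len, embedding_dim)
    x_text[0, 2:] = 0.0  # simulate padding rows whose embeddings are all zeros

    non_zero_mask = (x_text.sum(-1) != 0).float()       # (batch_size, seq_len)
    token_counts = non_zero_mask.sum(-1, keepdim=True)  # (batch_size, 1)
    pooled = (x_text * non_zero_mask.unsqueeze(-1)).sum(dim=1)
    pooled = torch.nan_to_num(pooled / token_counts.clamp(min=1.0), 0.0)
    print(pooled.shape)  # torch.Size([2, 3])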
-
-     def predict(
-         self,
-         text: List[str],
-         categorical_variables: List[List[int]],
-         top_k=1,
-         explain=False,
-         preprocess=True,
-     ):
-         """
-         Args:
-             text (List[str]): A list of text observations.
-             params (Optional[Dict[str, Any]]): Additional parameters to
-                 pass to the model for inference.
-             top_k (int): for each sentence, return the top_k most likely predictions (default: 1)
-             explain (bool): launch gradient integration to have an explanation of the prediction (default: False)
-             preprocess (bool): If True, preprocess text. Needs unidecode library.
-
-         Returns:
-             if explain is False:
-                 predictions (torch.Tensor, shape (len(text), top_k)): A tensor containing the top_k most likely codes to the query.
-                 confidence (torch.Tensor, shape (len(text), top_k)): A tensor array containing the corresponding confidence scores.
-             if explain is True:
-                 predictions (torch.Tensor, shape (len(text), top_k)): Containing the top_k most likely codes to the query.
-                 confidence (torch.Tensor, shape (len(text), top_k)): Corresponding confidence scores.
-                 all_attributions (torch.Tensor, shape (len(text), top_k, seq_len)): A tensor containing the attributions for each token in the text.
-                 x (torch.Tensor): A tensor containing the token indices of the text.
-                 id_to_token_dicts (List[Dict[int, str]]): A list of dictionaries mapping token indices to tokens (one for each sentence).
-                 token_to_id_dicts (List[Dict[str, int]]): A list of dictionaries mapping tokens to token indices: the reverse of those in id_to_token_dicts.
-                 text (list[str]): A plist containing the preprocessed text (one line for each sentence).
-         """
-
-         flag_change_embed = False
-         if explain:
-             if not HAS_CAPTUM:
-                 raise ImportError(
-                     "Captum is not installed and is required for explainability. Run 'pip install torchFastText[explainability]'."
-                 )
-             if self.direct_bagging:
-                 # Get back the classical embedding layer for explainability
-                 new_embed_layer = nn.Embedding(
-                     embedding_dim=self.embedding_dim,
-                     num_embeddings=self.num_rows,
-                     padding_idx=self.padding_idx,
-                     sparse=self.sparse,
-                 )
-                 new_embed_layer.load_state_dict(
-                     self.embeddings.state_dict()
-                 ) # No issues, as exactly the same parameters
-                 self.embeddings = new_embed_layer
-                 self.direct_bagging = (
-                     False # To inform the forward pass that we are not using EmbeddingBag anymore
-                 )
-                 flag_change_embed = True
-
-             lig = LayerIntegratedGradients(
-                 self, self.embeddings
-             ) # initialize a Captum layer gradient integrator
-
-         self.eval()
-         batch_size = len(text)
-
-         indices_batch, id_to_token_dicts, token_to_id_dicts = self.tokenizer.tokenize(
-             text, text_tokens=False, preprocess=preprocess
-         )
-
-         padding_index = (
-             self.tokenizer.get_buckets() + self.tokenizer.get_nwords()
-         ) # padding index, the integer value of the padding token
-
-         padded_batch = torch.nn.utils.rnn.pad_sequence(
-             indices_batch,
-             batch_first=True,
-             padding_value=padding_index,
-         ) # (batch_size, seq_len) - Tokenized (int) + padded text
-
-         x = padded_batch
-
-         if not self.no_cat_var:
-             other_features = []
-             # Transpose categorical_variables to iterate over features instead of samples
-             categorical_variables_transposed = categorical_variables.T
-             for i, categorical_variable in enumerate(categorical_variables_transposed):
-                 other_features.append(
-                     torch.tensor(categorical_variable).reshape(batch_size, -1).to(torch.int64)
-                 )
-
-             other_features = torch.stack(other_features).reshape(batch_size, -1).long()
-         else:
-             other_features = torch.empty(batch_size)
-
-         pred = self(
-             x, other_features
-         ) # forward pass, contains the prediction scores (len(text), num_classes)
-         label_scores = pred.detach().cpu()
-         label_scores_topk = torch.topk(label_scores, k=top_k, dim=1)
-
-         predictions = label_scores_topk.indices # get the top_k most likely predictions
-         confidence = torch.round(label_scores_topk.values, decimals=2) # and their scores
-
-         if explain:
-             assert not self.direct_bagging, "Direct bagging should be False for explainability"
-             all_attributions = []
-             for k in range(top_k):
-                 attributions = lig.attribute(
-                     (x, other_features), target=torch.Tensor(predictions[:, k]).long()
-                 ) # (batch_size, seq_len)
-                 attributions = attributions.sum(dim=-1)
-                 all_attributions.append(attributions.detach().cpu())
-
-             all_attributions = torch.stack(all_attributions, dim=1) # (batch_size, top_k, seq_len)
-
-             # Get back to initial embedding layer:
-             # EmbeddingBag -> Embedding -> EmbeddingBag
-             # or keep Embedding with no change
-             if flag_change_embed:
-                 new_embed_layer = nn.EmbeddingBag(
-                     embedding_dim=self.embedding_dim,
-                     num_embeddings=self.num_rows,
-                     padding_idx=self.padding_idx,
-                     sparse=self.sparse,
-                 )
-                 new_embed_layer.load_state_dict(
-                     self.embeddings.state_dict()
-                 ) # No issues, as exactly the same parameters
-                 self.embeddings = new_embed_layer
-                 self.direct_bagging = True
-             return (
-                 predictions,
-                 confidence,
-                 all_attributions,
-                 x,
-                 id_to_token_dicts,
-                 token_to_id_dicts,
-                 text,
-             )
-         else:
-             return predictions, confidence
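Assuming a trained model with its tokenizer attached (both names below are illustrative), the method was presumably called along these lines; with explain=True it additionally returns the Captum attributions, the padded token matrix and the token/id mappings listed in the docstring:

    import numpy as np

    # hypothetical usage sketch, not taken from the package's documentation
    predictions, confidence = model.predict(
        text=["some preprocessed text description"],
        categorical_variables=np.array([[0, 1]]),  # one row of categorical codes per sentence
        top_k=3,
        explain=False,
    )
    # predictions: (len(text), top_k) class indices; confidence: matching scores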
-
-     def predict_and_explain(self, text, categorical_variables, top_k=1, n=5, cutoff=0.65):
-         """
-         Args:
-             text (List[str]): A list of sentences.
-             params (Optional[Dict[str, Any]]): Additional parameters to
-                 pass to the model for inference.
-             top_k (int): for each sentence, return the top_k most likely predictions (default: 1)
-             n (int): mapping processed to original words: max number of candidate processed words to consider per original word (default: 5)
-             cutoff (float): mapping processed to original words: minimum similarity score to consider a candidate processed word (default: 0.75)
-
-         Returns:
-             predictions (torch.Tensor, shape (len(text), top_k)): Containing the top_k most likely codes to the query.
-             confidence (torch.Tensor, shape (len(text), top_k)): Corresponding confidence scores.
-             all_scores (List[List[List[float]]]): For each sentence, list of the top_k lists of attributions for each word in the sentence (one for each pred).
-         """
-
-         # Step 1: Get the predictions, confidence scores and attributions at token level
-         (
-             pred,
-             confidence,
-             all_attr,
-             tokenized_text,
-             id_to_token_dicts,
-             token_to_id_dicts,
-             processed_text,
-         ) = self.predict(
-             text=text, categorical_variables=categorical_variables, top_k=top_k, explain=True
-         )
-
-         tokenized_text_tokens = self.tokenizer._tokenized_text_in_tokens(
-             tokenized_text, id_to_token_dicts
-         )
-
-         # Step 2: Map the attributions at token level to the processed words
-         processed_word_to_score_dicts, processed_word_to_token_idx_dicts = (
-             compute_preprocessed_word_score(
-                 processed_text,
-                 tokenized_text_tokens,
-                 all_attr,
-                 id_to_token_dicts,
-                 token_to_id_dicts,
-                 min_n=self.tokenizer.min_n,
-                 padding_index=self.padding_idx,
-                 end_of_string_index=0,
-             )
-         )
-
-         # Step 3: Map the processed words to the original words
-         all_scores, orig_to_processed_mappings = compute_word_score(
-             processed_word_to_score_dicts, text, n=n, cutoff=cutoff
-         )
-
-         # Step 2bis: Get the attributions at letter level
-         all_scores_letters = explain_continuous(
-             text,
-             processed_text,
-             tokenized_text_tokens,
-             orig_to_processed_mappings,
-             processed_word_to_token_idx_dicts,
-             all_attr,
-             top_k,
-         )
-
-         return pred, confidence, all_scores, all_scores_letters
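For context, the constructor of this class lets the embedding-matrix size come either from a tokenizer or from an explicit num_rows. A minimal construction sketch using only the arguments shown in __init__ (all values are illustrative):

    model = FastTextModel(
        embedding_dim=80,
        num_classes=5,
        num_rows=100_000,                # rows in the embedding matrix
        categorical_vocabulary_sizes=[10, 20],
        categorical_embedding_dims=16,   # an int -> categorical embeddings are averaged
        padding_idx=0,
        sparse=False,
        direct_bagging=True,
    )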
-
-
- # ============================================================================
- # PyTorch Lightning Module
- # ============================================================================
-
- class FastTextModule(pl.LightningModule):
-     """Pytorch Lightning Module for FastTextModel."""
-
-     def __init__(
-         self,
-         model: FastTextModel,
-         loss,
-         optimizer,
-         optimizer_params,
-         scheduler,
-         scheduler_params,
-         scheduler_interval="epoch",
-         **kwargs,
-     ):
-         """
-         Initialize FastTextModule.
-
-         Args:
-             model: Model.
-             loss: Loss
-             optimizer: Optimizer
-             optimizer_params: Optimizer parameters.
-             scheduler: Scheduler.
-             scheduler_params: Scheduler parameters.
-             scheduler_interval: Scheduler interval.
-         """
-         super().__init__()
-         self.save_hyperparameters(ignore=["model", "loss"])
-
-         self.model = model
-         self.loss = loss
-         self.accuracy_fn = Accuracy(task="multiclass", num_classes=self.model.num_classes)
-         self.optimizer = optimizer
-         self.optimizer_params = optimizer_params
-         self.scheduler = scheduler
-         self.scheduler_params = scheduler_params
-         self.scheduler_interval = scheduler_interval
-
-     def forward(self, inputs) -> torch.Tensor:
-         """
-         Perform forward-pass.
-
-         Args:
-             batch (List[torch.LongTensor]): Batch to perform forward-pass on.
-
-         Returns (torch.Tensor): Prediction.
-         """
-         return self.model(inputs[0], inputs[1])
-
-     def training_step(self, batch, batch_idx: int) -> torch.Tensor:
-         """
-         Training step.
-
-         Args:
-             batch (List[torch.LongTensor]): Training batch.
-             batch_idx (int): Batch index.
-
-         Returns (torch.Tensor): Loss tensor.
-         """
-
-         inputs, targets = batch[:-1], batch[-1]
-         outputs = self.forward(inputs)
-         loss = self.loss(outputs, targets)
-         self.log("train_loss", loss, on_epoch=True, on_step=True, prog_bar=True)
-         accuracy = self.accuracy_fn(outputs, targets)
-         self.log("train_accuracy", accuracy, on_epoch=True, on_step=False, prog_bar=True)
-
-         torch.cuda.empty_cache()
-
-         return loss
-
-     def validation_step(self, batch, batch_idx: int):
-         """
-         Validation step.
-
-         Args:
-             batch (List[torch.LongTensor]): Validation batch.
-             batch_idx (int): Batch index.
-
-         Returns (torch.Tensor): Loss tensor.
-         """
-         inputs, targets = batch[:-1], batch[-1]
-         outputs = self.forward(inputs)
-         loss = self.loss(outputs, targets)
-         self.log("val_loss", loss, on_epoch=True, on_step=False, prog_bar=True, sync_dist=True)
-
-         accuracy = self.accuracy_fn(outputs, targets)
-         self.log("val_accuracy", accuracy, on_epoch=True, on_step=False, prog_bar=True)
-         return loss
-
-     def test_step(self, batch, batch_idx: int):
-         """
-         Test step.
-
-         Args:
-             batch (List[torch.LongTensor]): Test batch.
-             batch_idx (int): Batch index.
-
-         Returns (torch.Tensor): Loss tensor.
-         """
-         inputs, targets = batch[:-1], batch[-1]
-         outputs = self.forward(inputs)
-         loss = self.loss(outputs, targets)
-
-         accuracy = self.accuracy_fn(outputs, targets)
-
-         return loss, accuracy
-
-     def configure_optimizers(self):
-         """
-         Configure optimizer for Pytorch lighting.
-
-         Returns: Optimizer and scheduler for pytorch lighting.
-         """
-         optimizer = self.optimizer(self.parameters(), **self.optimizer_params)
-
-         # Only use scheduler if it's not ReduceLROnPlateau or if we can ensure val_loss is available
-         # For complex training setups, sometimes val_loss is not available every epoch
-         if hasattr(self.scheduler, '__name__') and 'ReduceLROnPlateau' in self.scheduler.__name__:
-             # For ReduceLROnPlateau, use train_loss as it's always available
-             scheduler = self.scheduler(optimizer, **self.scheduler_params)
-             scheduler_config = {
-                 "scheduler": scheduler,
-                 "monitor": "train_loss",
-                 "interval": self.scheduler_interval,
-             }
-             return [optimizer], [scheduler_config]
-         else:
-             # For other schedulers (StepLR, etc.), no monitoring needed
-             scheduler = self.scheduler(optimizer, **self.scheduler_params)
-             return [optimizer], [scheduler]
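The module stores the optimizer and scheduler classes together with their keyword arguments and only instantiates them in configure_optimizers. A wiring sketch under that contract (the optimizer, scheduler and hyperparameters here are illustrative choices, not values prescribed by the package):

    import torch
    import pytorch_lightning as pl

    module = FastTextModule(
        model=model,                      # a FastTextModel instance
        loss=torch.nn.CrossEntropyLoss(),
        optimizer=torch.optim.Adam,       # passed as a class, not an instance
        optimizer_params={"lr": 4e-3},
        scheduler=torch.optim.lr_scheduler.ReduceLROnPlateau,  # monitored on train_loss, per the branch above
        scheduler_params={"mode": "min", "patience": 3},
        scheduler_interval="epoch",
    )
    trainer = pl.Trainer(max_epochs=10)
    # trainer.fit(module, train_dataloader, val_dataloader)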
-
-
- # ============================================================================
- # Dataset
- # ============================================================================
-
- class FastTextModelDataset(torch.utils.data.Dataset):
-     """FastTextModelDataset class."""
-
-     def __init__(
-         self,
-         categorical_variables: List[List[int]],
-         texts: List[str],
-         tokenizer, # NGramTokenizer
-         outputs: List[int] = None,
-         **kwargs,
-     ):
-         """
-         Constructor for the TorchDataset class.
-
-         Args:
-             categorical_variables (List[List[int]]): The elements of this list
-                 are the values of each categorical variable across the dataset.
-             text (List[str]): List of text descriptions.
-             y (List[int]): List of outcomes.
-             tokenizer (Tokenizer): Tokenizer.
-         """
-
-         if categorical_variables is not None and len(categorical_variables) != len(texts):
-             raise ValueError("Categorical variables and texts must have the same length.")
-
-         if outputs is not None and len(outputs) != len(texts):
-             raise ValueError("Outputs and texts must have the same length.")
-
-         self.categorical_variables = categorical_variables
-         self.texts = texts
-         self.outputs = outputs
-         self.tokenizer = tokenizer
-
-     def __len__(self) -> int:
-         """
-         Returns length of the data.
-
-         Returns:
-             int: Number of observations.
-         """
-         return len(self.texts)
-
-     def __str__(self) -> str:
-         """
-         Returns description of the Dataset.
-
-         Returns:
-             str: Description.
-         """
-         return f"<FastTextModelDataset(N={len(self)})>"
-
-     def __getitem__(self, index: int) -> List:
-         """
-         Returns observation for a given index.
-
-         Args:
-             index (int): Index.
-
-         Returns:
-             List[int, str]: Observation with given index.
-         """
-         categorical_variables = (
-             self.categorical_variables[index] if self.categorical_variables is not None else None
-         )
-         text = self.texts[index]
-
-         if self.outputs is not None:
-             y = self.outputs[index]
-             return text, categorical_variables, y
-         else:
-             return text, categorical_variables
-
-     def collate_fn(self, batch):
-         """
-         Efficient batch processing without explicit loops.
-
-         Args:
-             batch: Data batch.
-
-         Returns:
-             Tuple[torch.LongTensor]: Observation with given index.
-         """
-
-         # Unzip the batch in one go using zip(*batch)
-         if self.outputs is not None:
-             text, *categorical_vars, y = zip(*batch)
-         else:
-             text, *categorical_vars = zip(*batch)
-
-         # Convert text to indices in parallel using map
-         indices_batch = list(map(lambda x: self.tokenizer.indices_matrix(x)[0], text))
-
-         # Get padding index once
-         padding_index = self.tokenizer.get_buckets() + self.tokenizer.get_nwords()
-
-         # Pad sequences efficiently
-         padded_batch = torch.nn.utils.rnn.pad_sequence(
-             indices_batch,
-             batch_first=True,
-             padding_value=padding_index,
-         )
-
-         # Handle categorical variables efficiently
-         if self.categorical_variables is not None:
-             categorical_tensors = torch.stack(
-                 [
-                     torch.tensor(cat_var, dtype=torch.float32)
-                     for cat_var in categorical_vars[
-                         0
-                     ] # Access first element since zip returns tuple
-                 ]
-             )
-         else:
-             categorical_tensors = torch.empty(
-                 padded_batch.shape[0], 1, dtype=torch.float32, device=padded_batch.device
-             )
-
-         if self.outputs is not None:
-             # Convert labels to tensor in one go
-             y = torch.tensor(y, dtype=torch.long)
-             return (padded_batch, categorical_tensors, y)
-         else:
-             return (padded_batch, categorical_tensors)
-
-     def create_dataloader(
-         self,
-         batch_size: int,
-         shuffle: bool = False,
-         drop_last: bool = False,
-         num_workers: int = os.cpu_count() - 1,
-         pin_memory: bool = True,
-         persistent_workers: bool = True,
-         **kwargs,
-     ) -> torch.utils.data.DataLoader:
-         """
-         Creates a Dataloader from the FastTextModelDataset.
-         Use collate_fn() to tokenize and pad the sequences.
-
-         Args:
-             batch_size (int): Batch size.
-             shuffle (bool, optional): Shuffle option. Defaults to False.
-             drop_last (bool, optional): Drop last option. Defaults to False.
-             num_workers (int, optional): Number of workers. Defaults to os.cpu_count() - 1.
-             pin_memory (bool, optional): Set True if working on GPU, False if CPU. Defaults to True.
-             persistent_workers (bool, optional): Set True for training, False for inference. Defaults to True.
-             **kwargs: Additional arguments for PyTorch DataLoader.
-
-         Returns:
-             torch.utils.data.DataLoader: Dataloader.
-         """
-
-         logger.info(f"Creating DataLoader with {num_workers} workers.")
-
-         # persistent_workers requires num_workers > 0
-         if num_workers == 0:
-             persistent_workers = False
-
-         return torch.utils.data.DataLoader(
-             dataset=self,
-             batch_size=batch_size,
-             collate_fn=self.collate_fn,
-             shuffle=shuffle,
-             drop_last=drop_last,
-             pin_memory=pin_memory,
-             num_workers=num_workers,
-             persistent_workers=persistent_workers,
-             **kwargs,
-         )
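Since tokenization and padding happen on the fly in collate_fn, a training dataloader could be assembled roughly as follows; the tokenizer construction is omitted because it lives in the (also deleted) tokenizer.py, and every name and value below is illustrative:

    dataset = FastTextModelDataset(
        categorical_variables=[[0, 1], [2, 0]],   # one row of categorical codes per text
        texts=["first text description", "second text description"],
        outputs=[0, 1],                           # class labels
        tokenizer=tokenizer,                      # an NGramTokenizer instance (construction omitted)
    )
    train_loader = dataset.create_dataloader(
        batch_size=2,
        shuffle=True,
        num_workers=0,   # persistent_workers is switched off automatically when num_workers == 0
    )
    # each batch: (padded_token_ids, categorical_tensors, labels)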