deepchopper-1.3.0-cp310-abi3-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. deepchopper/__init__.py +9 -0
  2. deepchopper/__init__.pyi +67 -0
  3. deepchopper/__main__.py +4 -0
  4. deepchopper/cli.py +260 -0
  5. deepchopper/data/__init__.py +15 -0
  6. deepchopper/data/components/__init__.py +1 -0
  7. deepchopper/data/encode_fq.py +41 -0
  8. deepchopper/data/fq_datamodule.py +352 -0
  9. deepchopper/data/hg_data.py +39 -0
  10. deepchopper/data/only_fq.py +388 -0
  11. deepchopper/deepchopper.abi3.so +0 -0
  12. deepchopper/eval.py +86 -0
  13. deepchopper/models/__init__.py +4 -0
  14. deepchopper/models/basic_module.py +243 -0
  15. deepchopper/models/callbacks.py +57 -0
  16. deepchopper/models/cnn.py +54 -0
  17. deepchopper/models/components/__init__.py +1 -0
  18. deepchopper/models/dc_hg.py +163 -0
  19. deepchopper/models/llm/__init__.py +32 -0
  20. deepchopper/models/llm/caduceus.py +55 -0
  21. deepchopper/models/llm/components.py +99 -0
  22. deepchopper/models/llm/head.py +102 -0
  23. deepchopper/models/llm/hyena.py +41 -0
  24. deepchopper/models/llm/metric.py +44 -0
  25. deepchopper/models/llm/tokenizer.py +205 -0
  26. deepchopper/models/transformer.py +107 -0
  27. deepchopper/py.typed +0 -0
  28. deepchopper/train.py +109 -0
  29. deepchopper/ui/__init__.py +1 -0
  30. deepchopper/ui/main.py +189 -0
  31. deepchopper/utils/__init__.py +37 -0
  32. deepchopper/utils/instantiators.py +54 -0
  33. deepchopper/utils/logging_utils.py +53 -0
  34. deepchopper/utils/preprocess.py +62 -0
  35. deepchopper/utils/print.py +102 -0
  36. deepchopper/utils/pylogger.py +57 -0
  37. deepchopper/utils/rich_utils.py +100 -0
  38. deepchopper/utils/utils.py +138 -0
  39. deepchopper-1.3.0.dist-info/METADATA +254 -0
  40. deepchopper-1.3.0.dist-info/RECORD +43 -0
  41. deepchopper-1.3.0.dist-info/WHEEL +4 -0
  42. deepchopper-1.3.0.dist-info/entry_points.txt +2 -0
  43. deepchopper-1.3.0.dist-info/licenses/LICENSE +201 -0
deepchopper/models/llm/components.py ADDED
@@ -0,0 +1,99 @@
+ import torch
+ from torch import nn
+ from transformers import AutoModel, PretrainedConfig, PreTrainedModel
+ from transformers.modeling_outputs import TokenClassifierOutput
+
+ from .head import TokenClassificationHead
+
+ HyenadnaMaxLengths = {
+     "hyenadna-tiny-1k-seqlen": 1024,
+     "hyenadna-small-32k-seqlen": 32768,
+     "hyenadna-medium-160k-seqlen": 160000,
+     "hyenadna-medium-450k-seqlen": 450000,  # T4 up to here
+     "hyenadna-large-1m-seqlen": 1_000_000,  # only A100 (paid tier)
+ }
+
+
+ class TokenClassificationConfig(PretrainedConfig):
+     """Configuration class to store the model's hyperparameters."""
+
+     model_type = "token-classification"
+
+     def __init__(
+         self,
+         input_size: int = 256,
+         lin1_size: int = 1024,
+         lin2_size: int = 1024,
+         num_class: int = 2,
+         *,
+         use_identity_layer_for_qual: bool = True,
+         use_qual: bool = True,
+         **kwargs,
+     ):
+         self.input_size = input_size
+         self.lin1_size = lin1_size
+         self.lin2_size = lin2_size
+         self.num_class = num_class
+         self.use_identity_layer_for_qual = use_identity_layer_for_qual
+         self.use_qual = use_qual
+         super().__init__(**kwargs)
+
+
+ class TokenClassification(PreTrainedModel):
+     """Token classification model."""
+
+     config_class = TokenClassificationConfig
+
+     def __init__(
+         self,
+         config,
+         hyenadna_model: str = "hyenadna-small-32k-seqlen",
+         **kwargs,
+     ):
+         super().__init__(config, **kwargs)
+         self.num_class = config.num_class
+         self.hyenadna_model_name = hyenadna_model
+         self.hyenadna = AutoModel.from_pretrained(f"LongSafari/{hyenadna_model}-hf", trust_remote_code=True)
+
+         self.head = TokenClassificationHead(
+             input_size=config.input_size,
+             lin1_size=config.lin1_size,
+             lin2_size=config.lin2_size,
+             num_class=config.num_class,
+             use_identity_layer_for_qual=config.use_identity_layer_for_qual,
+             use_qual=config.use_qual,
+         )
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         labels: torch.Tensor,
+         input_quals: torch.Tensor,
+         inputs_embeds: torch.FloatTensor | None = None,
+         output_hidden_states: bool | None = None,
+         return_dict: bool | None = None,
+     ):
+         transformer_outputs = self.hyenadna(
+             input_ids,
+             inputs_embeds=inputs_embeds,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+         _batch_size = input_ids.shape[0]
+         hidden_states = transformer_outputs[0]
+
+         logits = self.head(hidden_states, input_quals)
+         labels = labels.to(logits.device)
+         loss_fct = nn.CrossEntropyLoss()
+
+         loss = loss_fct(logits.view(-1, self.num_class), labels.view(-1))
+
+         return TokenClassifierOutput(
+             loss=loss,
+             logits=logits,
+             hidden_states=transformer_outputs.hidden_states,
+         )
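A minimal usage sketch for the classes above, not taken from the package: it assumes network access to the Hugging Face Hub for the LongSafari checkpoint, that the backbone's hidden size matches input_size=256, and that the import path follows the file list; tensor shapes are illustrative.

import torch

from deepchopper.models.llm.components import (  # path inferred from the file list
    TokenClassification,
    TokenClassificationConfig,
)

config = TokenClassificationConfig(input_size=256, num_class=2)
model = TokenClassification(config, hyenadna_model="hyenadna-small-32k-seqlen")

input_ids = torch.randint(0, 10, (1, 128))   # toy token ids
input_quals = torch.rand(1, 128)             # one quality value per token
labels = torch.randint(0, 2, (1, 128))       # per-token class labels

out = model(input_ids=input_ids, labels=labels, input_quals=input_quals)
print(out.loss, out.logits.shape)            # logits: (1, 128, 2)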
deepchopper/models/llm/head.py ADDED
@@ -0,0 +1,102 @@
+ import torch
+ import torch.nn.functional as F  # noqa: N812
+ from torch import nn
+
+
+ class TokenClassificationCnnHead(nn.Module):
+     def __init__(
+         self,
+         input_size,
+         number_of_classes,
+         num_filters,
+         filter_sizes,
+     ):
+         super().__init__()
+         self.number_of_classes = number_of_classes
+
+         self.qual_linear1 = nn.Linear(1, input_size)  # maps the scalar per-base quality to the input feature size
+
+         layers = []
+         in_channels = input_size
+
+         for idx, fs in enumerate(filter_sizes):
+             layers.append(nn.Conv1d(in_channels=in_channels, out_channels=num_filters[idx], kernel_size=fs, padding="same"))
+             layers.append(nn.BatchNorm1d(num_filters[idx]))
+             layers.append(nn.ReLU())
+             in_channels = num_filters[idx]
+
+         self.model = nn.Sequential(*layers)
+         self.dense = nn.Linear(in_channels, number_of_classes)
+
+     def forward(self, x: torch.Tensor, input_quals: torch.Tensor):
+         x = F.relu(x + self.qual_linear1(input_quals.unsqueeze(-1)))
+         x = x.transpose(1, 2)  # (batch, input_size, seq_len)
+         x = self.model(x)  # (batch, num_filters[-1], seq_len)
+         x = x.transpose(1, 2)  # (batch, seq_len, num_filters[-1])
+         return self.dense(x)
+
+
+ class TokenClassificationHead(nn.Module):
+     """Token classification head for the model."""
+
+     def __init__(
+         self,
+         input_size: int,
+         num_class: int,
+         lin1_size: int,
+         lin2_size: int,
+         *,
+         use_identity_layer_for_qual: bool,
+         use_qual: bool,
+     ):
+         """Initialize the neural network model.
+
+         Parameters:
+             input_size (int): The size of the input features.
+             num_class (int): The number of output classes.
+             lin1_size (int): The size of the first linear layer.
+             lin2_size (int): The size of the second linear layer.
+             use_identity_layer_for_qual (bool): Whether to use an identity layer for quality.
+             use_qual (bool): Whether to use quality in the model.
+         """
+         if lin1_size != lin2_size:
+             msg = f"{lin1_size=} and {lin2_size=} must be equal"
+             raise ValueError(msg)
+
+         super().__init__()
+         self.use_qual = use_qual
+         self.activation = nn.ReLU()
+         self.linear1 = nn.Linear(input_size, lin1_size)
+         self.linear2 = nn.Linear(lin1_size, lin2_size)
+         self.linear3 = nn.Linear(lin2_size, num_class)
+
+         self.qual_linear1 = nn.Identity() if use_identity_layer_for_qual else nn.Linear(1, lin1_size)
+
+     def forward(self, x: torch.Tensor, input_quals: torch.Tensor) -> torch.Tensor:
+         """Forward pass through the neural network model.
+
+         Parameters:
+             x (torch.Tensor): Input tensor to the model.
+             input_quals (torch.Tensor): Input tensor representing qualities.
+
+         Returns:
+             torch.Tensor: Output tensor from the model.
+
+         The input x is passed through the first linear layer and the activation.
+         If 'use_qual' is True, the projected input_quals are added to that output as a
+         residual, which is fed through the second linear layer with a skip connection
+         and the activation; this incorporates base qualities into the model's
+         predictions. If 'use_qual' is False, input_quals is ignored and the output of
+         the first layer goes directly through the second linear layer and the
+         activation. The final logits come from the third linear layer. The activation
+         function is set via self.activation during initialization.
+         """
+         output = self.activation(self.linear1(x))
+
+         if self.use_qual:
+             residual = output + self.qual_linear1(input_quals.unsqueeze(-1))  # may add activation
+             output = self.activation(self.linear2(residual) + residual)
+         else:
+             output = self.activation(self.linear2(output))
+
+         return self.linear3(output)
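Since the head is a plain nn.Module, it can be exercised without the backbone. A minimal shape check, fully local (the import path is inferred from the file list; tensor sizes are illustrative):

import torch

from deepchopper.models.llm.head import TokenClassificationHead  # path inferred from the file list

head = TokenClassificationHead(
    input_size=256,
    num_class=2,
    lin1_size=1024,
    lin2_size=1024,
    use_identity_layer_for_qual=False,
    use_qual=True,
)
hidden_states = torch.rand(4, 512, 256)  # (batch, seq_len, input_size), e.g. backbone output
input_quals = torch.rand(4, 512)         # one quality value per position
logits = head(hidden_states, input_quals)
print(logits.shape)                      # torch.Size([4, 512, 2])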
deepchopper/models/llm/hyena.py ADDED
@@ -0,0 +1,41 @@
+ import torch
+ from torch import nn
+ from transformers import AutoModel
+
+ # https://github.com/HazyResearch/hyena-dna
+
+
+ class TokenClassificationModule(nn.Module):
+     """Token classification model."""
+
+     def __init__(
+         self,
+         number_of_classes: int,
+         head: nn.Module,
+         backbone_name: str = "hyenadna-small-32k-seqlen",
+         *,
+         freeze_backbone=False,
+     ):
+         super().__init__()
+         self.number_of_classes = number_of_classes
+         self.backbone_name = backbone_name
+         self.backbone = AutoModel.from_pretrained(f"LongSafari/{backbone_name}-hf", trust_remote_code=True)
+         self.head = head
+
+         if freeze_backbone:
+             for param in self.backbone.parameters():
+                 param.requires_grad = False
+
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         input_quals: torch.Tensor,
+     ):
+         transformer_outputs = self.backbone(
+             input_ids,
+             inputs_embeds=None,
+             output_hidden_states=None,
+             return_dict=None,
+         )
+         hidden_states = transformer_outputs[0]
+         return self.head(hidden_states, input_quals)
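A sketch of wiring this backbone wrapper to the head from head.py (import paths inferred from the file list; input_size=256 is an assumption about the backbone's hidden size, and the checkpoint download requires Hub access):

from deepchopper.models.llm.head import TokenClassificationHead
from deepchopper.models.llm.hyena import TokenClassificationModule  # paths inferred from the file list

head = TokenClassificationHead(
    input_size=256,  # assumed to match the backbone's hidden size
    num_class=2,
    lin1_size=1024,
    lin2_size=1024,
    use_identity_layer_for_qual=True,
    use_qual=True,
)
# Downloads the LongSafari checkpoint from the Hugging Face Hub on first use;
# freeze_backbone=True leaves only the head trainable.
model = TokenClassificationModule(number_of_classes=2, head=head, freeze_backbone=True)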
deepchopper/models/llm/metric.py ADDED
@@ -0,0 +1,44 @@
+ import evaluate
+ import numpy as np
+
+ clf_metrics = evaluate.combine(["accuracy", "f1", "precision", "recall"])
+
+ IGNORE_INDEX = -100
+
+
+ def compute_metrics(p):
+     """Compute metrics for a given set of predictions and labels.
+
+     Parameters:
+         p (tuple): A tuple containing two numpy arrays - predictions and labels.
+             predictions: 3D numpy array of shape (batch_size, sequence_length, num_classes)
+             labels: 2D numpy array of shape (batch_size, sequence_length)
+
+     Returns:
+         dict: A dictionary containing computed metrics for the predictions.
+
+     Raises:
+         ValueError: If the input arrays are not of the expected shape.
+     """
+     predictions, labels = p
+     predictions = np.argmax(predictions, axis=2)
+     # Initialize lists to hold the filtered predictions and labels
+     true_predictions = []
+     true_labels = []
+
+     # Filter out '-100' labels and correspondingly filter predictions
+     for prediction, label in zip(predictions, labels, strict=True):
+         filtered_prediction = []
+         filtered_label = []
+
+         for pred, lab in zip(prediction, label, strict=True):
+             if lab != IGNORE_INDEX:
+                 filtered_prediction.append(pred)
+                 filtered_label.append(lab)
+         true_predictions.append(filtered_prediction)
+         true_labels.append(filtered_label)
+
+     for preds, refs in zip(true_predictions, true_labels, strict=True):
+         clf_metrics.add_batch(predictions=preds, references=refs)
+
+     return clf_metrics.compute()
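A toy invocation of compute_metrics (the import path is inferred from the file list; evaluate.combine fetches the metric scripts, so first use may require network access):

import numpy as np

from deepchopper.models.llm.metric import compute_metrics  # path inferred from the file list

rng = np.random.default_rng(0)
predictions = rng.random((2, 5, 2))            # (batch, seq_len, num_classes)
labels = np.array([[0, 1, 1, -100, -100],
                   [1, 0, -100, -100, -100]])  # -100 marks padded positions
print(compute_metrics((predictions, labels)))  # accuracy, f1, precision, recall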
deepchopper/models/llm/tokenizer.py ADDED
@@ -0,0 +1,205 @@
+ from functools import partial
+
+ import torch
+ from datasets import Dataset
+ from transformers import (
+     AutoTokenizer,
+     DataCollatorForTokenClassification,
+ )
+
+ import deepchopper
+
+ from .metric import IGNORE_INDEX
+
+
+ def pad_without_fast_tokenizer_warning(tokenizer, *pad_args, **pad_kwargs):
+     """Pads without triggering the warning about how using the pad function is sub-optimal when using a fast tokenizer."""
+     # To avoid errors when using Feature extractors
+     if not hasattr(tokenizer, "deprecation_warnings"):
+         return tokenizer.pad(*pad_args, **pad_kwargs)
+
+     # Save the state of the warning, then disable it
+     warning_state = tokenizer.deprecation_warnings.get("Asking-to-pad-a-fast-tokenizer", False)
+     tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True
+
+     try:
+         padded = tokenizer.pad(*pad_args, **pad_kwargs)
+     finally:
+         # Restore the state of the warning.
+         tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = warning_state
+
+     return padded
+
+
+ class DataCollatorForTokenClassificationWithQual(DataCollatorForTokenClassification):
+     def torch_call(self, features):
+         import torch
+
+         label_name = "label" if "label" in features[0] else "labels"
+         labels = [feature[label_name] for feature in features] if label_name in features[0] else None
+
+         qual_name = "input_quals"
+         qual_pad_token_id = 0
+         input_quals = [feature[qual_name] for feature in features]
+
+         id_name = "id"  # for prediction dataset
+
+         no_labels_features = [
+             {k: v for k, v in feature.items() if k not in [qual_name, label_name, id_name]} for feature in features
+         ]
+
+         batch = pad_without_fast_tokenizer_warning(
+             self.tokenizer,
+             no_labels_features,
+             padding=self.padding,
+             max_length=self.max_length,
+             pad_to_multiple_of=self.pad_to_multiple_of,
+             return_tensors="pt",
+         )
+
+         if labels is None:
+             return batch
+
+         sequence_length = batch["input_ids"].shape[1]
+         padding_side = self.tokenizer.padding_side
+
+         def to_list(tensor_or_iterable):
+             if isinstance(tensor_or_iterable, torch.Tensor):
+                 return tensor_or_iterable.tolist()
+             return list(tensor_or_iterable)
+
+         if padding_side == "right":
+             batch[label_name] = [
+                 to_list(label) + [self.label_pad_token_id] * (sequence_length - len(label)) for label in labels
+             ]
+             batch[qual_name] = [
+                 to_list(qual) + [qual_pad_token_id] * (sequence_length - len(qual)) for qual in input_quals
+             ]
+         else:
+             batch[label_name] = [
+                 [self.label_pad_token_id] * (sequence_length - len(label)) + to_list(label) for label in labels
+             ]
+             batch[qual_name] = [
+                 [qual_pad_token_id] * (sequence_length - len(qual)) + to_list(qual) for qual in input_quals
+             ]
+
+         batch[label_name] = torch.tensor(batch[label_name], dtype=torch.int8)
+         batch[qual_name] = torch.tensor(batch[qual_name], dtype=torch.float32)
+
+         # for prediction dataset: keep the id feature
+         if id_name in features[0]:
+             batch[id_name] = torch.tensor([to_list(feature[id_name]) for feature in features], dtype=torch.int8)
+
+         return batch
+
+
+ def load_tokenizer_from_hyena_model(model_name):
+     max_lengths = {
+         "hyenadna-tiny-1k-seqlen": 1024,
+         "hyenadna-small-32k-seqlen": 32768,
+         "hyenadna-medium-160k-seqlen": 160000,
+         "hyenadna-medium-450k-seqlen": 450000,  # T4 up to here
+         "hyenadna-large-1m-seqlen": 1_000_000,  # only A100 (paid tier)
+     }
+
+     if model_name not in max_lengths:
+         msg = f"Model name {model_name} not found in available models."
+         raise ValueError(msg)
+
+     max_length = max_lengths[model_name]
+     # bfloat16 for better speed and reduced memory usage
+     model_name = f"LongSafari/{model_name}-hf"
+     return AutoTokenizer.from_pretrained(
+         model_name, max_length=max_length, truncation=True, padding=True, trust_remote_code=True, force_download=False
+     )
+
+
+ def to_list(data):
+     return list(data)
+
+
+ def tokenize_and_align_labels_and_quals(data, tokenizer, max_length, pad_qual=0, pad_label=IGNORE_INDEX):
+     tokenized_inputs = tokenizer(data["seq"], max_length=max_length, truncation=True, padding=True)
+
+     if isinstance(data["qual"], bytes):
+         data["qual"] = torch.Tensor(to_list(data["qual"]))
+
+     if len(data["seq"]) >= max_length:
+         if data["target"][1] + 2 > max_length:
+             labels = torch.tensor([*deepchopper.vectorize_target(0, 0, max_length - 1), pad_label])
+             quals = torch.cat((data["qual"][: max_length - 1], torch.tensor([pad_qual]))).float()
+             normalized_quals = torch.nn.functional.normalize(quals, dim=0)
+         else:
+             labels = torch.tensor([*deepchopper.vectorize_targets(data["target"].numpy(), max_length - 1), pad_label])
+             quals = torch.cat((data["qual"][: max_length - 1], torch.tensor([pad_qual]))).float()
+             normalized_quals = torch.nn.functional.normalize(quals, dim=0)
+     else:
+         labels = torch.tensor([*deepchopper.vectorize_targets(data["target"].numpy(), len(data["seq"])), pad_label])
+         quals = torch.cat((data["qual"], torch.tensor([pad_qual]))).float()
+         normalized_quals = torch.nn.functional.normalize(quals, dim=0)
+
+     tokenized_inputs.update({"labels": labels, "input_quals": normalized_quals})
+     return tokenized_inputs
+
+
+ def tokenize_and_align_labels_and_quals_ids(
+     data, tokenizer, max_length, pad_qual=0, pad_label=IGNORE_INDEX, max_id_length=256
+ ):
+     tokenized_inputs = tokenizer(data["seq"], max_length=max_length, truncation=True, padding=True)
+
+     truncation = False
+
+     # TODO: remove target and labels during prediction
+
+     if len(data["seq"]) >= max_length:
+         truncation = True
+         if data["target"][1] + 2 > max_length:
+             labels = torch.tensor([*deepchopper.vectorize_target(0, 0, max_length - 1), pad_label])
+             quals = torch.cat((data["qual"][: max_length - 1], torch.tensor([pad_qual]))).float()
+             normalized_quals = torch.nn.functional.normalize(quals, dim=0)
+         else:
+             labels = torch.tensor([*deepchopper.vectorize_targets(data["target"].numpy(), max_length - 1), pad_label])
+             quals = torch.cat((data["qual"][: max_length - 1], torch.tensor([pad_qual]))).float()
+             normalized_quals = torch.nn.functional.normalize(quals, dim=0)
+     else:
+         labels = torch.tensor([*deepchopper.vectorize_targets(data["target"].numpy(), len(data["seq"])), pad_label])
+         quals = torch.cat((data["qual"], torch.tensor([pad_qual]))).float()
+         normalized_quals = torch.nn.functional.normalize(quals, dim=0)
+
+     # change id to ascii values
+     new_id = [len(data["id"]), int(truncation)]
+     new_id += [ord(char) for char in data["id"]]
+     if len(new_id) > max_id_length:
+         new_id = new_id[:max_id_length]
+     elif len(new_id) < max_id_length:
+         new_id += [0] * (max_id_length - len(new_id))
+
+     tokenized_inputs.update({"labels": labels, "input_quals": normalized_quals, "id": new_id})
+     return tokenized_inputs
+
+
+ def tokenize_dataset(dataset, tokenizer, max_length):
+     """Tokenizes the input dataset using the provided tokenizer and aligns labels and qualities.
+
+     Args:
+         dataset (Dataset): The input dataset to be tokenized.
+         tokenizer (Tokenizer): The tokenizer to be used for tokenization.
+         max_length (int): The maximum length of the tokenized sequences.
+
+     Returns:
+         Tokenized dataset with aligned labels and qualities.
+
+     Raises:
+         ValueError: If the dataset is empty or if the tokenizer is not provided.
+         TypeError: If the dataset is not of type Dataset or if the tokenizer is not of type Tokenizer.
+     """
+     if not dataset:
+         raise ValueError("Input dataset is empty")
+     if not tokenizer:
+         raise ValueError("Tokenizer is not provided")
+     if not isinstance(dataset, Dataset):
+         raise TypeError("Input dataset must be of type Dataset")
+
+     return dataset.map(
+         partial(tokenize_and_align_labels_and_quals, tokenizer=tokenizer, max_length=max_length)
+     ).remove_columns(["id", "seq", "qual", "target"])
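A sketch of loading the tokenizer and using the collator defined above (the import path is inferred from the file list; the tokenizer download requires Hub access, and the toy features only illustrate the padding behavior):

from deepchopper.models.llm.tokenizer import (  # path inferred from the file list
    DataCollatorForTokenClassificationWithQual,
    load_tokenizer_from_hyena_model,
)

# Downloads the LongSafari tokenizer from the Hugging Face Hub on first use.
tokenizer = load_tokenizer_from_hyena_model("hyenadna-small-32k-seqlen")
collator = DataCollatorForTokenClassificationWithQual(tokenizer=tokenizer)

# The collator pads input_ids via the tokenizer, labels with -100, and
# input_quals with 0 to a common length, returning PyTorch tensors.
batch = collator([
    {"input_ids": [1, 2, 3], "labels": [0, 0, 1], "input_quals": [0.1, 0.2, 0.3]},
    {"input_ids": [1, 2], "labels": [0, 1], "input_quals": [0.4, 0.5]},
])
print(batch["input_ids"].shape, batch["labels"].shape, batch["input_quals"].shape)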
deepchopper/models/transformer.py ADDED
@@ -0,0 +1,107 @@
+ import torch
+ import torch.nn.functional as F  # noqa: N812
+ from torch import nn
+ from torch.nn import (
+     TransformerDecoder,
+     TransformerDecoderLayer,
+     TransformerEncoder,
+     TransformerEncoderLayer,
+ )
+
+
+ class TokenClassificationModule(nn.Module):
+     def __init__(self, vocab_size, d_model, nhead, num_encoder_layers, dim_feedforward, number_of_classes):
+         super().__init__()
+         self.number_of_classes = number_of_classes
+
+         self.qual_linear1 = nn.Sequential(
+             nn.Linear(1, d_model),
+         )
+
+         self.embedding = nn.Embedding(vocab_size, d_model)
+         encoder_layer = TransformerEncoderLayer(
+             d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward, batch_first=True
+         )
+
+         self.transformer_encoder = TransformerEncoder(encoder_layer, num_layers=num_encoder_layers)
+
+         self.classifier = nn.Linear(d_model, number_of_classes)
+         self.d_model = d_model
+
+     def forward(self, src: torch.Tensor, input_quals: torch.Tensor, src_mask=None):
+         seq_length = src.size(1)
+
+         # Calculate positional embeddings
+         position = torch.arange(seq_length, dtype=torch.float, device=src.device).unsqueeze(1)
+         div_term = torch.exp(
+             torch.arange(0, self.d_model, 2, device=src.device).float()
+             * (-torch.log(torch.tensor(10000.0)) / self.d_model)
+         )
+
+         pos_embedding = torch.zeros((seq_length, self.d_model), device=src.device)
+         pos_embedding[:, 0::2] = torch.sin(position * div_term)
+         pos_embedding[:, 1::2] = torch.cos(position * div_term)
+
+         src = F.relu(self.embedding(src) + self.qual_linear1(input_quals.unsqueeze(-1)))
+         src = src + pos_embedding.unsqueeze(0)
+
+         if src_mask is not None:
+             src_mask = src_mask.to(dtype=torch.bool)
+
+         output = self.transformer_encoder(src, src_key_padding_mask=src_mask)
+         return self.classifier(output)
+
+
+ class Seq2SeqTokenClassifier(nn.Module):
+     def __init__(
+         self,
+         vocab_size,
+         d_model,
+         nhead,
+         num_encoder_layers,
+         num_decoder_layers,
+         dim_feedforward,
+         number_of_classes,
+     ):
+         super().__init__()
+         self.d_model = d_model
+
+         self.embedding = nn.Embedding(vocab_size, d_model)
+         self.qual_linear1 = nn.Linear(1, d_model)
+
+         encoder_layer = TransformerEncoderLayer(
+             d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward, batch_first=True
+         )
+         self.transformer_encoder = TransformerEncoder(encoder_layer, num_layers=num_encoder_layers)
+
+         decoder_layer = TransformerDecoderLayer(
+             d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward, batch_first=True
+         )
+         self.transformer_decoder = TransformerDecoder(decoder_layer, num_layers=num_decoder_layers)
+
+         self.classifier = nn.Linear(d_model, number_of_classes)
+
+     def forward(
+         self,
+         src: torch.Tensor,
+         input_quals: torch.Tensor,
+         tgt: torch.Tensor,
+         src_mask=None,
+         tgt_mask=None,
+     ):
+         src_embeddings = self.embedding(src)
+         qual_embeddings = self.qual_linear1(input_quals.unsqueeze(-1))
+
+         encoder_input = src_embeddings + qual_embeddings
+
+         if src_mask is not None:
+             src_mask = src_mask.to(dtype=torch.bool)
+
+         encoder_output = self.transformer_encoder(encoder_input, src_key_padding_mask=src_mask)
+
+         tgt_embeddings = self.embedding(tgt)
+         decoder_output = self.transformer_decoder(
+             tgt_embeddings, encoder_output, memory_key_padding_mask=src_mask, tgt_mask=tgt_mask
+         )
+
+         return self.classifier(decoder_output)
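Unlike the hyenadna-backed variants, this encoder is fully local. A minimal shape check (the import path is inferred from the file list; all hyperparameters are illustrative):

import torch

from deepchopper.models.transformer import TokenClassificationModule  # path inferred from the file list

model = TokenClassificationModule(
    vocab_size=16,
    d_model=64,
    nhead=4,
    num_encoder_layers=2,
    dim_feedforward=128,
    number_of_classes=2,
)
src = torch.randint(0, 16, (2, 100))  # (batch, seq_len) token ids
input_quals = torch.rand(2, 100)      # per-position qualities
logits = model(src, input_quals)
print(logits.shape)                   # torch.Size([2, 100, 2])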
deepchopper/py.typed ADDED
File without changes