fine 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. checksums.yaml +7 -0
  2. data/.rspec +3 -0
  3. data/CHANGELOG.md +38 -0
  4. data/Gemfile +6 -0
  5. data/Gemfile.lock +167 -0
  6. data/LICENSE +21 -0
  7. data/README.md +212 -0
  8. data/Rakefile +6 -0
  9. data/docs/installation.md +151 -0
  10. data/docs/tutorials/llm-fine-tuning.md +246 -0
  11. data/docs/tutorials/model-export.md +200 -0
  12. data/docs/tutorials/siglip2-image-classification.md +130 -0
  13. data/docs/tutorials/siglip2-object-recognition.md +203 -0
  14. data/docs/tutorials/siglip2-similarity-search.md +152 -0
  15. data/docs/tutorials/text-classification.md +233 -0
  16. data/docs/tutorials/text-embeddings.md +211 -0
  17. data/examples/basic_classification.rb +70 -0
  18. data/examples/data/tool_calls.jsonl +30 -0
  19. data/examples/demo_training.rb +78 -0
  20. data/examples/finetune_gemma3_tools.rb +135 -0
  21. data/examples/real_llm_test.rb +128 -0
  22. data/examples/real_text_classification_test.rb +90 -0
  23. data/examples/real_text_embedder_test.rb +110 -0
  24. data/examples/real_training_test.rb +88 -0
  25. data/examples/test_export.rb +28 -0
  26. data/examples/test_image_classifier.rb +79 -0
  27. data/examples/test_llm.rb +100 -0
  28. data/examples/test_text_classifier.rb +59 -0
  29. data/lib/fine/callbacks/base.rb +140 -0
  30. data/lib/fine/callbacks/progress_bar.rb +66 -0
  31. data/lib/fine/configuration.rb +106 -0
  32. data/lib/fine/datasets/data_loader.rb +63 -0
  33. data/lib/fine/datasets/image_dataset.rb +203 -0
  34. data/lib/fine/datasets/instruction_dataset.rb +226 -0
  35. data/lib/fine/datasets/text_data_loader.rb +88 -0
  36. data/lib/fine/datasets/text_dataset.rb +266 -0
  37. data/lib/fine/error.rb +49 -0
  38. data/lib/fine/export/gguf_exporter.rb +424 -0
  39. data/lib/fine/export/onnx_exporter.rb +249 -0
  40. data/lib/fine/export.rb +53 -0
  41. data/lib/fine/hub/config_loader.rb +145 -0
  42. data/lib/fine/hub/model_downloader.rb +136 -0
  43. data/lib/fine/hub/safetensors_loader.rb +108 -0
  44. data/lib/fine/image_classifier.rb +256 -0
  45. data/lib/fine/llm.rb +336 -0
  46. data/lib/fine/models/base.rb +48 -0
  47. data/lib/fine/models/bert_encoder.rb +202 -0
  48. data/lib/fine/models/bert_for_sequence_classification.rb +226 -0
  49. data/lib/fine/models/causal_lm.rb +279 -0
  50. data/lib/fine/models/classification_head.rb +24 -0
  51. data/lib/fine/models/gemma3_decoder.rb +244 -0
  52. data/lib/fine/models/llama_decoder.rb +297 -0
  53. data/lib/fine/models/sentence_transformer.rb +202 -0
  54. data/lib/fine/models/siglip2_for_image_classification.rb +155 -0
  55. data/lib/fine/models/siglip2_vision_encoder.rb +190 -0
  56. data/lib/fine/text_classifier.rb +250 -0
  57. data/lib/fine/text_embedder.rb +221 -0
  58. data/lib/fine/tokenizers/auto_tokenizer.rb +208 -0
  59. data/lib/fine/training/llm_trainer.rb +212 -0
  60. data/lib/fine/training/text_trainer.rb +275 -0
  61. data/lib/fine/training/trainer.rb +194 -0
  62. data/lib/fine/transforms/compose.rb +28 -0
  63. data/lib/fine/transforms/normalize.rb +33 -0
  64. data/lib/fine/transforms/resize.rb +35 -0
  65. data/lib/fine/transforms/to_tensor.rb +53 -0
  66. data/lib/fine/version.rb +3 -0
  67. data/lib/fine.rb +112 -0
  68. data/mise.toml +2 -0
  69. metadata +240 -0
data/lib/fine/models/bert_for_sequence_classification.rb
@@ -0,0 +1,226 @@
+# frozen_string_literal: true
+
+module Fine
+  module Models
+    # BERT model with a classification head for sequence classification
+    class BertForSequenceClassification < Base
+      attr_reader :encoder, :classifier, :num_labels
+
+      # Load a pretrained model from the Hugging Face Hub
+      #
+      # @param model_id [String] HuggingFace model ID
+      # @param num_labels [Integer] Number of classification labels
+      # @param dropout [Float] Dropout rate for the classifier
+      # @return [BertForSequenceClassification]
+      def self.from_pretrained(model_id, num_labels:, dropout: 0.1)
+        # Download model files
+        downloader = Hub::ModelDownloader.new(model_id)
+        model_path = downloader.download
+
+        # Load config
+        config = Hub::ConfigLoader.from_pretrained(model_path)
+
+        # Create model
+        model = new(config, num_labels: num_labels, dropout: dropout)
+
+        # Load pretrained weights into the encoder
+        weights_path = downloader.file_path("model.safetensors")
+
+        if File.exist?(weights_path)
+          load_result = load_pretrained_weights(model, weights_path)
+
+          if load_result[:missing_keys].any?
+            # Only warn about unexpected missing keys (the classifier is expected to be missing)
+            encoder_missing = load_result[:missing_keys].reject { |k| k.include?("classifier") }
+            if encoder_missing.any?
+              warn "Missing encoder keys: #{encoder_missing.first(5).join(', ')}..."
+            end
+          end
+        end
+
+        model
+      end
+
+      # Load a saved fine-tuned model
+      #
+      # @param path [String] Path to the saved model directory
+      # @return [BertForSequenceClassification]
+      def self.load(path)
+        raise ModelNotFoundError.new(path) unless File.directory?(path)
+
+        config = Hub::ConfigLoader.from_pretrained(path)
+        num_labels = config.config["num_labels"] || config.config["id2label"]&.size
+
+        raise ConfigurationError, "Cannot determine num_labels from saved model" unless num_labels
+
+        model = new(config, num_labels: num_labels)
+
+        weights_path = File.join(path, "model.safetensors")
+        Hub::SafetensorsLoader.load_into_model(model, weights_path, strict: false)
+
+        model
+      end
+
+      def initialize(config, num_labels:, dropout: 0.1)
+        super(config)
+
+        @num_labels = num_labels
+
+        # Encoder
+        @encoder = BertEncoder.new(config)
+
+        # Classification head
+        @dropout = Torch::NN::Dropout.new(p: dropout)
+        @classifier = Torch::NN::Linear.new(config.hidden_size, num_labels)
+      end
+
+      def forward(input_ids, attention_mask: nil, token_type_ids: nil, labels: nil)
+        # Get encoder outputs
+        encoder_output = @encoder.call(
+          input_ids,
+          attention_mask: attention_mask,
+          token_type_ids: token_type_ids
+        )
+
+        # Use the pooled output for classification
+        pooled_output = encoder_output[:pooler_output]
+        pooled_output = @dropout.call(pooled_output)
+
+        # Classification logits
+        logits = @classifier.call(pooled_output)
+
+        # Compute the loss if labels are provided
+        if labels
+          loss = Torch::NN::Functional.cross_entropy(logits, labels)
+          { loss: loss, logits: logits }
+        else
+          { logits: logits }
+        end
+      end
+
+      # Predict the class for an input
+      #
+      # @param input_ids [Torch::Tensor] Input token IDs
+      # @param attention_mask [Torch::Tensor] Attention mask
+      # @return [Torch::Tensor] Predicted class indices
+      def predict(input_ids, attention_mask: nil, token_type_ids: nil)
+        eval
+        Torch.no_grad do
+          output = forward(input_ids, attention_mask: attention_mask, token_type_ids: token_type_ids)
+          output[:logits].argmax(dim: 1)
+        end
+      end
+
+      # Get class probabilities
+      #
+      # @return [Torch::Tensor] Class probabilities
+      def predict_proba(input_ids, attention_mask: nil, token_type_ids: nil)
+        eval
+        Torch.no_grad do
+          output = forward(input_ids, attention_mask: attention_mask, token_type_ids: token_type_ids)
+          Torch::NN::Functional.softmax(output[:logits], dim: 1)
+        end
+      end
+
+      # Save the model
+      #
+      # @param path [String] Directory path
+      # @param label_map [Hash, nil] Label-to-id mapping
+      def save(path, label_map: nil)
+        FileUtils.mkdir_p(path)
+
+        # Save weights
+        weights_path = File.join(path, "model.safetensors")
+        Safetensors::Torch.save_file(state_dict, weights_path)
+
+        # Build config
+        save_config = @config.to_h.merge(
+          "num_labels" => @num_labels,
+          "model_type" => "bert_classification"
+        )
+
+        if label_map
+          save_config["id2label"] = label_map.invert.sort.to_h
+          save_config["label2id"] = label_map
+        end
+
+        config_path = File.join(path, "config.json")
+        File.write(config_path, JSON.pretty_generate(save_config))
+      end
+
+      def self.load_pretrained_weights(model, weights_path)
+        tensors = Safetensors::Torch.load_file(weights_path)
+
+        # Map HuggingFace weight names to our model structure
+        mapped = {}
+        unexpected_keys = []
+
+        model_keys = model.state_dict.keys
+
+        tensors.each do |name, tensor|
+          mapped_name = map_bert_weight_name(name)
+
+          if model_keys.include?(mapped_name)
+            mapped[mapped_name] = tensor
+          else
+            unexpected_keys << name
+          end
+        end
+
+        missing_keys = model_keys - mapped.keys
+
+        # Use no_grad and copy! since torch.rb doesn't support strict: false
+        Torch.no_grad do
+          state_dict = model.state_dict
+          mapped.each do |name, tensor|
+            state_dict[name].copy!(tensor) if state_dict.key?(name)
+          end
+        end
+
+        { missing_keys: missing_keys, unexpected_keys: unexpected_keys }
+      end
+
+      def self.map_bert_weight_name(hf_name)
+        name = hf_name.dup
+
+        # Embeddings
+        name = name.sub("bert.embeddings.word_embeddings", "encoder.word_embeddings")
+        name = name.sub("bert.embeddings.position_embeddings", "encoder.position_embeddings")
+        name = name.sub("bert.embeddings.token_type_embeddings", "encoder.token_type_embeddings")
+        name = name.sub("bert.embeddings.LayerNorm", "encoder.embeddings_layer_norm")
+        name = name.sub("embeddings.word_embeddings", "encoder.word_embeddings")
+        name = name.sub("embeddings.position_embeddings", "encoder.position_embeddings")
+        name = name.sub("embeddings.token_type_embeddings", "encoder.token_type_embeddings")
+        name = name.sub("embeddings.LayerNorm", "encoder.embeddings_layer_norm")
+
+        # Encoder layers
+        # (trailing dots keep the second gsub from re-matching the already
+        # rewritten "encoder.layers." prefix)
+        name = name.gsub("bert.encoder.layer.", "encoder.layers.")
+        name = name.gsub("encoder.layer.", "encoder.layers.")
+
+        # Attention
+        name = name.gsub(".attention.self.query", ".attention.query")
+        name = name.gsub(".attention.self.key", ".attention.key")
+        name = name.gsub(".attention.self.value", ".attention.value")
+        name = name.gsub(".attention.output.dense", ".attention.out")
+        name = name.gsub(".attention.output.LayerNorm", ".attention_layer_norm")
+
+        # FFN
+        name = name.gsub(".intermediate.dense", ".intermediate")
+        name = name.gsub(".output.dense", ".output")
+        name = name.gsub(".output.LayerNorm", ".output_layer_norm")
+
+        # Pooler
+        name = name.sub("bert.pooler.dense", "encoder.pooler_dense")
+        name = name.sub("pooler.dense", "encoder.pooler_dense")
+
+        # Classifier weights already use the name our model expects
+        name
+      end
+
+      # `private` has no effect on singleton methods, so mark the class-level
+      # helpers private explicitly
+      private_class_method :load_pretrained_weights, :map_bert_weight_name
+    end
+  end
+end
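
For orientation, here is a minimal usage sketch of the class above. It is illustrative, not taken from the gem's docs: the model ID, the hand-built token IDs, and the label map are placeholder assumptions; in real use the IDs would come from the gem's tokenizer (data/lib/fine/tokenizers/auto_tokenizer.rb in the file list above).

    # Hypothetical usage sketch; model ID, token IDs, and labels are placeholders.
    model = Fine::Models::BertForSequenceClassification.from_pretrained(
      "bert-base-uncased",
      num_labels: 2
    )

    input_ids = Torch.tensor([[101, 7592, 2088, 102]])  # one pre-tokenized sequence
    attention_mask = Torch.ones_like(input_ids)

    model.predict(input_ids, attention_mask: attention_mask)        # => class indices
    model.predict_proba(input_ids, attention_mask: attention_mask)  # => softmax probabilities

    # Persist weights plus config (id2label/label2id) for a later .load
    model.save("tmp/sentiment", label_map: { "negative" => 0, "positive" => 1 })
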
data/lib/fine/models/causal_lm.rb
@@ -0,0 +1,279 @@
+# frozen_string_literal: true
+
+module Fine
+  module Models
+    # Causal Language Model for text generation
+    #
+    # Wraps a decoder model with a language modeling head for next-token prediction.
+    class CausalLM < Base
+      attr_reader :decoder, :lm_head
+
+      # Load from the HuggingFace Hub
+      #
+      # @param model_id [String] Model ID (e.g., "google/gemma-2b", "meta-llama/Llama-3.2-1B")
+      # @param dtype [Symbol] Data type (:float32, :float16, :bfloat16, or :auto)
+      # @return [CausalLM]
+      def self.from_pretrained(model_id, dtype: :auto)
+        downloader = Hub::ModelDownloader.new(model_id)
+        model_path = downloader.download
+
+        config = Hub::ConfigLoader.from_pretrained(model_path)
+
+        # Determine dtype from the config if :auto
+        if dtype == :auto
+          config_dtype = config.config["torch_dtype"]
+          dtype = case config_dtype
+                  when "bfloat16" then :bfloat16
+                  when "float16" then :float16
+                  else :float32
+                  end
+        end
+
+        model = new(config)
+
+        # Convert the model to the target dtype before loading weights (saves memory)
+        model.to(dtype) if dtype != :float32
+
+        # Load weights
+        weights_path = downloader.file_path("model.safetensors")
+        if File.exist?(weights_path)
+          load_pretrained_weights(model, weights_path)
+        else
+          # Try sharded weights
+          load_sharded_weights(model, model_path)
+        end
+
+        model
+      end
+
+      # Load from a saved model
+      def self.load(path)
+        raise ModelNotFoundError.new(path) unless File.directory?(path)
+
+        config = Hub::ConfigLoader.from_pretrained(path)
+        model = new(config)
+
+        weights_path = File.join(path, "model.safetensors")
+        Hub::SafetensorsLoader.load_into_model(model, weights_path, strict: false)
+
+        model
+      end
+
+      def initialize(config)
+        super(config)
+
+        # Use the appropriate decoder based on model type
+        @decoder = if config.model_type&.include?("gemma3")
+          Gemma3Decoder.new(config)
+        else
+          LlamaDecoder.new(config)
+        end
+
+        # LM head (often tied to the embeddings)
+        @lm_head = Torch::NN::Linear.new(config.hidden_size, config.vocab_size, bias: false)
+
+        # Optionally tie weights with the embeddings
+        @tie_word_embeddings = config.config["tie_word_embeddings"] != false
+      end
+
+      def forward(input_ids, attention_mask: nil, labels: nil, return_logits: nil)
+        # Default: return logits only when there are no labels (inference mode)
+        return_logits = labels.nil? if return_logits.nil?
+
+        # Get decoder outputs
+        outputs = @decoder.call(input_ids, attention_mask: attention_mask)
+        hidden_states = outputs[:last_hidden_state]
+
+        # LM head
+        logits = @lm_head.call(hidden_states)
+
+        # Compute the loss if labels are provided
+        if labels
+          # Shift for next-token prediction
+          shift_logits = logits[0.., 0...-1, 0..].contiguous
+          shift_labels = labels[0.., 1..].contiguous
+
+          # Compute cross-entropy loss
+          vocab_size = logits.size(-1)
+          loss = Torch::NN::Functional.cross_entropy(
+            shift_logits.view(-1, vocab_size),
+            shift_labels.view(-1),
+            ignore_index: -100
+          )
+
+          # Don't return logits during training to save memory
+          return_logits ? { loss: loss, logits: logits } : { loss: loss }
+        else
+          { logits: logits }
+        end
+      end
+
+      # Generate text autoregressively
+      #
+      # @param input_ids [Torch::Tensor] Input token IDs
+      # @param max_new_tokens [Integer] Maximum tokens to generate
+      # @param temperature [Float] Sampling temperature
+      # @param top_p [Float] Nucleus sampling threshold
+      # @param top_k [Integer] Top-k sampling
+      # @param do_sample [Boolean] Whether to sample or use greedy decoding
+      # @return [Torch::Tensor] Generated token IDs
+      def generate(input_ids, max_new_tokens: 100, temperature: 1.0, top_p: 0.9,
+                   top_k: 50, do_sample: true, eos_token_id: nil, pad_token_id: nil)
+        eval
+        generated = input_ids.clone
+
+        Torch.no_grad do
+          max_new_tokens.times do
+            # Forward pass
+            outputs = forward(generated)
+            next_token_logits = outputs[:logits][0.., -1, 0..]
+
+            # Apply temperature
+            next_token_logits = next_token_logits / temperature if temperature != 1.0
+
+            if do_sample
+              # Top-k filtering
+              if top_k > 0
+                indices_to_remove = next_token_logits < Torch.topk(next_token_logits, top_k).values[0.., -1, nil]
+                next_token_logits = next_token_logits.masked_fill(indices_to_remove, -Float::INFINITY)
+              end
+
+              # Top-p (nucleus) filtering
+              if top_p < 1.0
+                sorted_logits, sorted_indices = Torch.sort(next_token_logits, descending: true)
+                cumulative_probs = Torch.cumsum(Torch::NN::Functional.softmax(sorted_logits, dim: -1), dim: -1)
+
+                # Remove tokens with cumulative probability above the threshold
+                sorted_indices_to_remove = cumulative_probs > top_p
+                sorted_indices_to_remove[0.., 1..] = sorted_indices_to_remove[0.., 0...-1].clone
+                sorted_indices_to_remove[0.., 0] = false
+
+                indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
+                next_token_logits = next_token_logits.masked_fill(indices_to_remove, -Float::INFINITY)
+              end
+
+              # Sample
+              probs = Torch::NN::Functional.softmax(next_token_logits, dim: -1)
+              next_token = Torch.multinomial(probs, num_samples: 1)
+            else
+              # Greedy decoding
+              next_token = next_token_logits.argmax(dim: -1, keepdim: true)
+            end
+
+            # Append to the generated sequence
+            generated = Torch.cat([generated, next_token], dim: 1)
+
+            # Check for EOS
+            if eos_token_id
+              # Handle both single and array EOS token IDs
+              eos_ids = eos_token_id.is_a?(Array) ? eos_token_id : [eos_token_id]
+              next_token_val = next_token[0, 0].item rescue next_token.to(:int64)[0, 0].item
+              break if eos_ids.include?(next_token_val)
+            end
+          end
+        end
+
+        generated
+      end
+
+      def save(path)
+        FileUtils.mkdir_p(path)
+
+        weights_path = File.join(path, "model.safetensors")
+        Safetensors::Torch.save_file(state_dict, weights_path)
+
+        save_config = @config.to_h.merge("model_type" => "causal_lm")
+
+        config_path = File.join(path, "config.json")
+        File.write(config_path, JSON.pretty_generate(save_config))
+      end
+
+      def self.load_pretrained_weights(model, weights_path)
+        # Load and copy weights one at a time to minimize memory usage
+        model_state = model.state_dict
+        model_keys = model_state.keys
+
+        Torch.no_grad do
+          Safetensors::Torch.load_file(weights_path).each do |name, tensor|
+            mapped_name = map_llama_weight_name(name)
+            if model_keys.include?(mapped_name)
+              target = model_state[mapped_name]
+              # Convert dtype if needed
+              tensor = tensor.to(target.dtype) if tensor.dtype != target.dtype
+              target.copy!(tensor)
+            end
+          end
+        end
+
+        # Force garbage collection to free the loaded tensors
+        GC.start
+      end
+
+      def self.load_sharded_weights(model, model_path)
+        # Find all safetensors shards
+        shards = Dir.glob(File.join(model_path, "model-*.safetensors")).sort
+
+        return if shards.empty?
+
+        model_state = model.state_dict
+        model_keys = model_state.keys
+
+        # Load each shard and copy weights immediately to minimize memory
+        Torch.no_grad do
+          shards.each do |shard_path|
+            Safetensors::Torch.load_file(shard_path).each do |name, tensor|
+              mapped_name = map_llama_weight_name(name)
+              if model_keys.include?(mapped_name)
+                target = model_state[mapped_name]
+                tensor = tensor.to(target.dtype) if tensor.dtype != target.dtype
+                target.copy!(tensor)
+              end
+            end
+            # GC after each shard to free memory
+            GC.start
+          end
+        end
+      end
+
+      def self.map_llama_weight_name(hf_name)
+        name = hf_name.dup
+
+        # Embeddings (lm_head already matches our attribute name)
+        name = name.sub("model.embed_tokens", "decoder.embed_tokens")
+
+        # Layers
+        name = name.gsub("model.layers", "decoder.layers")
+
+        # Final norm
+        name = name.sub("model.norm", "decoder.norm")
+
+        # The per-layer names — attention (q_proj/k_proj/v_proj/o_proj), MLP
+        # (gate_proj/up_proj/down_proj), the standard norms (input_layernorm,
+        # post_attention_layernorm), and the Gemma 3 extras (q_norm/k_norm,
+        # pre/post_feedforward_layernorm) — already match between HuggingFace
+        # checkpoints and our decoder modules, so no further remapping is needed.
+
+        name
+      end
+
+      # `private` has no effect on singleton methods, so mark the class-level
+      # helpers private explicitly
+      private_class_method :load_pretrained_weights, :load_sharded_weights,
+                           :map_llama_weight_name
+    end
+  end
+end
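
Again a hedged sketch of the intended call pattern for generate above; the model ID and token IDs are placeholders, and eos_token_id would normally be read from the tokenizer or generation config rather than hard-coded.

    # Hypothetical usage sketch; prompt IDs and the EOS ID are placeholders.
    model = Fine::Models::CausalLM.from_pretrained("meta-llama/Llama-3.2-1B", dtype: :auto)

    prompt_ids = Torch.tensor([[1, 15043, 3186]])  # a single tokenized prompt

    output_ids = model.generate(
      prompt_ids,
      max_new_tokens: 50,
      temperature: 0.8,   # < 1.0 sharpens the next-token distribution
      top_k: 50,          # keep only the 50 highest-probability tokens...
      top_p: 0.9,         # ...then the smallest set covering 90% of the mass
      do_sample: true,
      eos_token_id: 2     # stop early at end-of-sequence (model-specific)
    )
    # output_ids still includes the prompt; slice it off before decoding
    new_tokens = output_ids[0.., prompt_ids.size(1)..]
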
data/lib/fine/models/classification_head.rb
@@ -0,0 +1,24 @@
+# frozen_string_literal: true
+
+module Fine
+  module Models
+    # Simple classification head (linear layer)
+    class ClassificationHead < Torch::NN::Module
+      attr_reader :in_features, :num_classes
+
+      def initialize(in_features, num_classes, dropout: 0.0)
+        super()
+        @in_features = in_features
+        @num_classes = num_classes
+
+        @dropout = Torch::NN::Dropout.new(p: dropout) if dropout.positive?
+        @classifier = Torch::NN::Linear.new(in_features, num_classes)
+      end
+
+      def forward(x)
+        x = @dropout.call(x) if @dropout
+        @classifier.call(x)
+      end
+    end
+  end
+end
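
ClassificationHead is self-contained, so a quick sketch with made-up dimensions shows its whole contract:

    # Hypothetical dimensions: 768-dim pooled features, 3 target classes
    head = Fine::Models::ClassificationHead.new(768, 3, dropout: 0.1)

    features = Torch.randn(4, 768)  # batch of 4 feature vectors
    logits = head.call(features)    # => tensor of shape [4, 3]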