fine 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. checksums.yaml +7 -0
  2. data/.rspec +3 -0
  3. data/CHANGELOG.md +38 -0
  4. data/Gemfile +6 -0
  5. data/Gemfile.lock +167 -0
  6. data/LICENSE +21 -0
  7. data/README.md +212 -0
  8. data/Rakefile +6 -0
  9. data/docs/installation.md +151 -0
  10. data/docs/tutorials/llm-fine-tuning.md +246 -0
  11. data/docs/tutorials/model-export.md +200 -0
  12. data/docs/tutorials/siglip2-image-classification.md +130 -0
  13. data/docs/tutorials/siglip2-object-recognition.md +203 -0
  14. data/docs/tutorials/siglip2-similarity-search.md +152 -0
  15. data/docs/tutorials/text-classification.md +233 -0
  16. data/docs/tutorials/text-embeddings.md +211 -0
  17. data/examples/basic_classification.rb +70 -0
  18. data/examples/data/tool_calls.jsonl +30 -0
  19. data/examples/demo_training.rb +78 -0
  20. data/examples/finetune_gemma3_tools.rb +135 -0
  21. data/examples/real_llm_test.rb +128 -0
  22. data/examples/real_text_classification_test.rb +90 -0
  23. data/examples/real_text_embedder_test.rb +110 -0
  24. data/examples/real_training_test.rb +88 -0
  25. data/examples/test_export.rb +28 -0
  26. data/examples/test_image_classifier.rb +79 -0
  27. data/examples/test_llm.rb +100 -0
  28. data/examples/test_text_classifier.rb +59 -0
  29. data/lib/fine/callbacks/base.rb +140 -0
  30. data/lib/fine/callbacks/progress_bar.rb +66 -0
  31. data/lib/fine/configuration.rb +106 -0
  32. data/lib/fine/datasets/data_loader.rb +63 -0
  33. data/lib/fine/datasets/image_dataset.rb +203 -0
  34. data/lib/fine/datasets/instruction_dataset.rb +226 -0
  35. data/lib/fine/datasets/text_data_loader.rb +88 -0
  36. data/lib/fine/datasets/text_dataset.rb +266 -0
  37. data/lib/fine/error.rb +49 -0
  38. data/lib/fine/export/gguf_exporter.rb +424 -0
  39. data/lib/fine/export/onnx_exporter.rb +249 -0
  40. data/lib/fine/export.rb +53 -0
  41. data/lib/fine/hub/config_loader.rb +145 -0
  42. data/lib/fine/hub/model_downloader.rb +136 -0
  43. data/lib/fine/hub/safetensors_loader.rb +108 -0
  44. data/lib/fine/image_classifier.rb +256 -0
  45. data/lib/fine/llm.rb +336 -0
  46. data/lib/fine/models/base.rb +48 -0
  47. data/lib/fine/models/bert_encoder.rb +202 -0
  48. data/lib/fine/models/bert_for_sequence_classification.rb +226 -0
  49. data/lib/fine/models/causal_lm.rb +279 -0
  50. data/lib/fine/models/classification_head.rb +24 -0
  51. data/lib/fine/models/gemma3_decoder.rb +244 -0
  52. data/lib/fine/models/llama_decoder.rb +297 -0
  53. data/lib/fine/models/sentence_transformer.rb +202 -0
  54. data/lib/fine/models/siglip2_for_image_classification.rb +155 -0
  55. data/lib/fine/models/siglip2_vision_encoder.rb +190 -0
  56. data/lib/fine/text_classifier.rb +250 -0
  57. data/lib/fine/text_embedder.rb +221 -0
  58. data/lib/fine/tokenizers/auto_tokenizer.rb +208 -0
  59. data/lib/fine/training/llm_trainer.rb +212 -0
  60. data/lib/fine/training/text_trainer.rb +275 -0
  61. data/lib/fine/training/trainer.rb +194 -0
  62. data/lib/fine/transforms/compose.rb +28 -0
  63. data/lib/fine/transforms/normalize.rb +33 -0
  64. data/lib/fine/transforms/resize.rb +35 -0
  65. data/lib/fine/transforms/to_tensor.rb +53 -0
  66. data/lib/fine/version.rb +3 -0
  67. data/lib/fine.rb +112 -0
  68. data/mise.toml +2 -0
  69. metadata +240 -0
data/lib/fine/export/gguf_exporter.rb
@@ -0,0 +1,424 @@
+ # frozen_string_literal: true
+
+ module Fine
+   module Export
+     # Export LLMs to GGUF format for llama.cpp, ollama, etc.
+     #
+     # @example Basic export
+     #   llm = Fine::LLM.load("my_llama")
+     #   Fine::Export::GGUFExporter.export(llm, "model.gguf")
+     #
+     # @example With quantization
+     #   Fine::Export::GGUFExporter.export(
+     #     llm,
+     #     "model-q4.gguf",
+     #     quantization: :q4_0
+     #   )
+     class GGUFExporter
+       # GGUF magic number and version
+       GGUF_MAGIC = 0x46554747 # "GGUF" in little-endian
+       GGUF_VERSION = 3
+
+       # GGUF value types
+       GGUF_TYPE_UINT8 = 0
+       GGUF_TYPE_INT8 = 1
+       GGUF_TYPE_UINT16 = 2
+       GGUF_TYPE_INT16 = 3
+       GGUF_TYPE_UINT32 = 4
+       GGUF_TYPE_INT32 = 5
+       GGUF_TYPE_FLOAT32 = 6
+       GGUF_TYPE_BOOL = 7
+       GGUF_TYPE_STRING = 8
+       GGUF_TYPE_ARRAY = 9
+       GGUF_TYPE_UINT64 = 10
+       GGUF_TYPE_INT64 = 11
+       GGUF_TYPE_FLOAT64 = 12
+
+       # GGML tensor types
+       GGML_TYPE_F32 = 0
+       GGML_TYPE_F16 = 1
+       GGML_TYPE_Q4_0 = 2
+       GGML_TYPE_Q4_1 = 3
+       GGML_TYPE_Q5_0 = 6
+       GGML_TYPE_Q5_1 = 7
+       GGML_TYPE_Q8_0 = 8
+       GGML_TYPE_Q8_1 = 9
+       GGML_TYPE_Q2_K = 10
+       GGML_TYPE_Q3_K = 11
+       GGML_TYPE_Q4_K = 12
+       GGML_TYPE_Q5_K = 13
+       GGML_TYPE_Q6_K = 14
+       GGML_TYPE_Q8_K = 15
+
+       QUANTIZATION_TYPES = {
+         f32: GGML_TYPE_F32,
+         f16: GGML_TYPE_F16,
+         q4_0: GGML_TYPE_Q4_0,
+         q4_1: GGML_TYPE_Q4_1,
+         q5_0: GGML_TYPE_Q5_0,
+         q5_1: GGML_TYPE_Q5_1,
+         q8_0: GGML_TYPE_Q8_0,
+         q4_k: GGML_TYPE_Q4_K,
+         q5_k: GGML_TYPE_Q5_K,
+         q6_k: GGML_TYPE_Q6_K
+       }.freeze
+
+       class << self
+         # Export a Fine::LLM to GGUF format
+         #
+         # @param llm [Fine::LLM] The LLM to export
+         # @param output_path [String] Path for the output GGUF file
+         # @param quantization [Symbol] Quantization type (:f16, :q4_0, :q4_k, :q8_0, etc.)
+         # @param metadata [Hash] Additional metadata to include
+         def export(llm, output_path, quantization: :f16, metadata: {})
+           unless llm.is_a?(Fine::LLM)
+             raise ExportError, "GGUF export only supports Fine::LLM models"
+           end
+
+           unless llm.model
+             raise ExportError, "Model not loaded or trained"
+           end
+
+           exporter = new(llm, output_path, quantization, metadata)
+           exporter.export
+         end
+       end
+
+       def initialize(llm, output_path, quantization, metadata)
+         @llm = llm
+         @output_path = output_path
+         @quantization = quantization
+         @metadata = metadata
+         @model = llm.model
+         @config = extract_config
+       end
+
+       def export
+         File.open(@output_path, "wb") do |file|
+           @file = file
+
+           write_header
+           write_metadata
+           write_tensors
+         end
+
+         @output_path
+       end
+
+       private
+
+       def extract_config
+         model_config = @model.config
+
+         {
+           vocab_size: model_config.vocab_size,
+           hidden_size: model_config.hidden_size,
+           intermediate_size: model_config.intermediate_size,
+           num_hidden_layers: model_config.num_hidden_layers,
+           num_attention_heads: model_config.num_attention_heads,
+           num_key_value_heads: model_config.num_key_value_heads || model_config.num_attention_heads,
+           max_position_embeddings: model_config.max_position_embeddings || 2048,
+           rms_norm_eps: model_config.rms_norm_eps || 1e-6,
+           rope_theta: model_config.rope_theta || 10000.0
+         }
+       end
+
+       def write_header
+         # Magic number
+         @file.write([GGUF_MAGIC].pack("V"))
+
+         # Version
+         @file.write([GGUF_VERSION].pack("V"))
+
+         # Tensor count (backpatched once the tensors are written)
+         @tensor_count_pos = @file.pos
+         @file.write([0].pack("Q<"))
+
+         # Metadata KV count (backpatched once the metadata is written)
+         @kv_count_pos = @file.pos
+         @file.write([0].pack("Q<"))
+       end
+
+       def write_metadata
+         kv_count = 0
+
+         # Architecture
+         write_string_kv("general.architecture", "llama")
+         kv_count += 1
+
+         write_string_kv("general.name", @llm.model_id || "fine-tuned-model")
+         kv_count += 1
+
+         # Model parameters
+         write_uint32_kv("llama.context_length", @config[:max_position_embeddings])
+         kv_count += 1
+
+         write_uint32_kv("llama.embedding_length", @config[:hidden_size])
+         kv_count += 1
+
+         write_uint32_kv("llama.block_count", @config[:num_hidden_layers])
+         kv_count += 1
+
+         write_uint32_kv("llama.feed_forward_length", @config[:intermediate_size])
+         kv_count += 1
+
+         write_uint32_kv("llama.attention.head_count", @config[:num_attention_heads])
+         kv_count += 1
+
+         write_uint32_kv("llama.attention.head_count_kv", @config[:num_key_value_heads])
+         kv_count += 1
+
+         write_float32_kv("llama.rope.freq_base", @config[:rope_theta])
+         kv_count += 1
+
+         write_float32_kv("llama.attention.layer_norm_rms_epsilon", @config[:rms_norm_eps])
+         kv_count += 1
+
+         # Tokenizer info (if available)
+         if @llm.tokenizer
+           write_string_kv("tokenizer.ggml.model", "llama")
+           kv_count += 1
+
+           if @llm.tokenizer.respond_to?(:vocab_size)
+             write_uint32_kv("llama.vocab_size", @llm.tokenizer.vocab_size)
+             kv_count += 1
+           end
+         end
+
+         # Custom metadata (unsupported value types are skipped so the
+         # KV count stays accurate)
+         @metadata.each do |key, value|
+           case value
+           when String
+             write_string_kv("general.#{key}", value)
+           when Integer
+             write_uint32_kv("general.#{key}", value)
+           when Float
+             write_float32_kv("general.#{key}", value)
+           else
+             next
+           end
+           kv_count += 1
+         end
+
+         # Backpatch the KV count in the header
+         current_pos = @file.pos
+         @file.seek(@kv_count_pos)
+         @file.write([kv_count].pack("Q<"))
+         @file.seek(current_pos)
+       end
+
+       def write_tensors
+         tensor_count = 0
+         tensor_infos = []
+         tensor_data = []
+
+         state_dict = @model.state_dict
+
+         state_dict.each do |name, tensor|
+           gguf_name = convert_tensor_name(name)
+           next unless gguf_name
+
+           # Quantize tensor
+           quantized, dtype = quantize_tensor(tensor, name)
+
+           tensor_infos << {
+             name: gguf_name,
+             dims: tensor.shape.reverse, # GGUF uses reversed dimensions
+             dtype: dtype
+           }
+
+           tensor_data << quantized
+           tensor_count += 1
+         end
+
+         # Compute each tensor's offset relative to the start of the
+         # 32-byte-aligned data section before the info table is written
+         offset = 0
+         tensor_infos.each_with_index do |info, idx|
+           info[:offset] = offset
+           offset += tensor_data[idx].bytesize
+           offset += (32 - (offset % 32)) % 32
+         end
+
+         # Write tensor infos
+         tensor_infos.each do |info|
+           write_tensor_info(info)
+         end
+
+         # Alignment padding before the data section
+         align_to(32)
+
+         # Write tensor data (each tensor starts on a 32-byte boundary,
+         # matching the offsets computed above)
+         tensor_data.each do |data|
+           align_to(32)
+           @file.write(data)
+         end
+
+         # Backpatch the tensor count in the header
+         current_pos = @file.pos
+         @file.seek(@tensor_count_pos)
+         @file.write([tensor_count].pack("Q<"))
+         @file.seek(current_pos)
+       end
+
+       def convert_tensor_name(torch_name)
+         # Map torch.rb/HuggingFace names to GGUF names
+         name = torch_name.dup
+
+         mappings = {
+           "decoder.embed_tokens.weight" => "token_embd.weight",
+           "decoder.norm.weight" => "output_norm.weight",
+           "lm_head.weight" => "output.weight"
+         }
+
+         return mappings[name] if mappings.key?(name)
+
+         # Layer mappings
+         if name =~ /decoder\.layers\.(\d+)\./
+           layer_num = $1
+
+           layer_mappings = {
+             "input_layernorm.weight" => "blk.#{layer_num}.attn_norm.weight",
+             "post_attention_layernorm.weight" => "blk.#{layer_num}.ffn_norm.weight",
+             "self_attn.q_proj.weight" => "blk.#{layer_num}.attn_q.weight",
+             "self_attn.k_proj.weight" => "blk.#{layer_num}.attn_k.weight",
+             "self_attn.v_proj.weight" => "blk.#{layer_num}.attn_v.weight",
+             "self_attn.o_proj.weight" => "blk.#{layer_num}.attn_output.weight",
+             "mlp.gate_proj.weight" => "blk.#{layer_num}.ffn_gate.weight",
+             "mlp.up_proj.weight" => "blk.#{layer_num}.ffn_up.weight",
+             "mlp.down_proj.weight" => "blk.#{layer_num}.ffn_down.weight"
+           }
+
+           suffix = name.sub(/decoder\.layers\.\d+\./, "")
+           return layer_mappings[suffix]
+         end
+
+         nil # Skip unknown tensors
+       end
+
+       def quantize_tensor(tensor, name)
+         tensor = tensor.cpu.contiguous
+
+         # Always keep embeddings and norms in higher precision
+         if name.include?("embed") || name.include?("norm") || name.include?("lm_head")
+           return [tensor_to_f16(tensor), GGML_TYPE_F16]
+         end
+
+         case @quantization
+         when :f32
+           [tensor_to_f32(tensor), GGML_TYPE_F32]
+         when :f16
+           [tensor_to_f16(tensor), GGML_TYPE_F16]
+         when :q8_0
+           quantize_q8_0(tensor)
+         when :q4_0
+           quantize_q4_0(tensor)
+         when :q4_k, :q5_k, :q6_k
+           # K-quants are more complex; fall back to Q8_0 for now
+           quantize_q8_0(tensor)
+         else
+           [tensor_to_f16(tensor), GGML_TYPE_F16]
+         end
+       end
+
+       def tensor_to_f32(tensor)
+         tensor.to(:float32).data_ptr_bytes
+       end
+
+       def tensor_to_f16(tensor)
+         tensor.to(:float16).data_ptr_bytes
+       end
+
+       def quantize_q8_0(tensor)
+         # Q8_0: 8-bit quantization with block size 32
+         # (per block: one float16 scale followed by 32 int8 values)
+         block_size = 32
+         data = tensor.to(:float32).flatten.to_a
+
+         quantized = []
+
+         data.each_slice(block_size) do |block|
+           block += [0.0] * (block_size - block.size) if block.size < block_size
+
+           # Find scale (max absolute value)
+           max_abs = block.map(&:abs).max
+           scale = max_abs / 127.0
+           scale = 1.0 if scale == 0
+
+           # Quantize
+           quantized << f16_bytes(scale) # float16 scale
+           block.each do |val|
+             q = (val / scale).round.clamp(-128, 127)
+             quantized << [q].pack("c")
+           end
+         end
+
+         [quantized.join, GGML_TYPE_Q8_0]
+       end
+
+       def quantize_q4_0(tensor)
+         # Q4_0: 4-bit quantization with block size 32
+         # (per block: one float16 scale followed by 16 bytes of nibbles)
+         block_size = 32
+         data = tensor.to(:float32).flatten.to_a
+
+         quantized = []
+
+         data.each_slice(block_size) do |block|
+           block += [0.0] * (block_size - block.size) if block.size < block_size
+
+           # Find scale
+           max_abs = block.map(&:abs).max
+           scale = max_abs / 7.0
+           scale = 1.0 if scale == 0
+
+           # Quantize to 4-bit
+           quantized << f16_bytes(scale) # float16 scale
+
+           # ggml pairs element j with element j + 16: the low nibble comes
+           # from the first half of the block, the high nibble from the second
+           half = block_size / 2
+           half.times do |j|
+             q0 = ((block[j] / scale).round.clamp(-8, 7) + 8) & 0x0F
+             q1 = ((block[j + half] / scale).round.clamp(-8, 7) + 8) & 0x0F
+             quantized << [q0 | (q1 << 4)].pack("C")
+           end
+         end
+
+         [quantized.join, GGML_TYPE_Q4_0]
+       end
+
+       # Array#pack has no half-precision directive, so build the IEEE 754
+       # binary16 bit pattern from the float32 representation by hand
+       # (mantissa is truncated, subnormals flush to zero, overflow
+       # saturates to infinity)
+       def f16_bytes(value)
+         bits = [value].pack("e").unpack1("V")
+         sign = (bits >> 16) & 0x8000
+         exp = ((bits >> 23) & 0xFF) - 127 + 15
+         mant = (bits >> 13) & 0x3FF
+
+         half = if exp <= 0
+           sign
+         elsif exp >= 31
+           sign | 0x7C00
+         else
+           sign | (exp << 10) | mant
+         end
+
+         [half].pack("v")
+       end
+
+       def write_tensor_info(info)
+         # Name
+         write_string(info[:name])
+
+         # Number of dimensions
+         @file.write([info[:dims].size].pack("V"))
+
+         # Dimensions
+         info[:dims].each do |dim|
+           @file.write([dim].pack("Q<"))
+         end
+
+         # Type
+         @file.write([info[:dtype]].pack("V"))
+
+         # Offset into the aligned data section (computed in write_tensors)
+         @file.write([info[:offset]].pack("Q<"))
+       end
+
+       def write_string_kv(key, value)
+         write_string(key)
+         @file.write([GGUF_TYPE_STRING].pack("V"))
+         write_string(value)
+       end
+
+       def write_uint32_kv(key, value)
+         write_string(key)
+         @file.write([GGUF_TYPE_UINT32].pack("V"))
+         @file.write([value].pack("V"))
+       end
+
+       def write_float32_kv(key, value)
+         write_string(key)
+         @file.write([GGUF_TYPE_FLOAT32].pack("V"))
+         @file.write([value].pack("e"))
+       end
+
+       def write_string(str)
+         @file.write([str.bytesize].pack("Q<"))
+         @file.write(str)
+       end
+
+       def align_to(alignment)
+         current = @file.pos
+         padding = (alignment - (current % alignment)) % alignment
+         @file.write("\x00" * padding) if padding > 0
+       end
+     end
+   end
+ end
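
The writer above fixes the GGUF layout end to end: a header (uint32 magic, uint32 version, uint64 tensor count, uint64 KV count), the metadata key/value section, the tensor info table, then 32-byte-aligned tensor data. The block sizes follow from the quantizers: Q8_0 stores a 2-byte scale plus 32 int8 values per 32 weights (34 bytes, roughly 8.5 bits per weight), while Q4_0 stores a 2-byte scale plus 16 packed nibble bytes (18 bytes, roughly 4.5 bits per weight). As a quick sanity check, the header can be read back from an exported file; a minimal sketch, assuming a hypothetical model.gguf produced by the exporter:

    # Read back the header fields that GGUFExporter#write_header emits
    File.open("model.gguf", "rb") do |f|
      magic, version = f.read(8).unpack("VV")
      tensor_count, kv_count = f.read(16).unpack("Q<Q<")
      raise "not a GGUF file" unless magic == Fine::Export::GGUFExporter::GGUF_MAGIC
      puts "GGUF v#{version}: #{tensor_count} tensors, #{kv_count} metadata keys"
    end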
data/lib/fine/export/onnx_exporter.rb
@@ -0,0 +1,249 @@
+ # frozen_string_literal: true
+
+ require "fileutils"
+
+ module Fine
+   module Export
+     # Export models to ONNX format
+     #
+     # @example Export a text classifier
+     #   classifier = Fine::TextClassifier.load("my_model")
+     #   Fine::Export::ONNXExporter.export(classifier, "model.onnx")
+     #
+     # @example Export with options
+     #   Fine::Export::ONNXExporter.export(
+     #     model,
+     #     "model.onnx",
+     #     opset_version: 14,
+     #     dynamic_axes: true
+     #   )
+     class ONNXExporter
+       SUPPORTED_TYPES = [
+         Fine::TextClassifier,
+         Fine::TextEmbedder,
+         Fine::ImageClassifier,
+         Fine::LLM
+       ].freeze
+
+       class << self
+         # Export a Fine model to ONNX format
+         #
+         # @param fine_model [TextClassifier, TextEmbedder, ImageClassifier, LLM] The model to export
+         # @param output_path [String] Path for the output ONNX file
+         # @param opset_version [Integer] ONNX opset version (default: 14)
+         # @param dynamic_axes [Boolean] Use dynamic axes for variable batch/sequence (default: true)
+         # @param quantize [Symbol, nil] Quantization type (:int8, :uint8, nil)
+         def export(fine_model, output_path, opset_version: 14, dynamic_axes: true, quantize: nil)
+           validate_model(fine_model)
+
+           model = fine_model.model
+           model.eval
+
+           # Get example inputs based on model type
+           example_inputs, input_names, output_names, dynamic_axes_config =
+             prepare_export_config(fine_model, dynamic_axes)
+
+           # Export to ONNX
+           Torch::ONNX.export(
+             model,
+             example_inputs,
+             output_path,
+             input_names: input_names,
+             output_names: output_names,
+             dynamic_axes: dynamic_axes_config,
+             opset_version: opset_version,
+             do_constant_folding: true
+           )
+
+           # Optional quantization
+           if quantize
+             quantize_model(output_path, quantize)
+           end
+
+           output_path
+         end
+
+         # Export only the encoder/backbone (useful for embeddings)
+         #
+         # @param fine_model [TextEmbedder, ImageClassifier] Model with encoder
+         # @param output_path [String] Output path
+         def export_encoder(fine_model, output_path, **options)
+           unless fine_model.respond_to?(:model) && fine_model.model.respond_to?(:encoder)
+             raise ExportError, "Model does not have an encoder"
+           end
+
+           encoder = fine_model.model.encoder
+           encoder.eval
+
+           example_inputs, input_names, output_names, dynamic_axes_config =
+             prepare_encoder_config(fine_model)
+
+           Torch::ONNX.export(
+             encoder,
+             example_inputs,
+             output_path,
+             input_names: input_names,
+             output_names: output_names,
+             dynamic_axes: dynamic_axes_config,
+             opset_version: options[:opset_version] || 14
+           )
+
+           output_path
+         end
+
+         private
+
+         def validate_model(model)
+           unless SUPPORTED_TYPES.any? { |t| model.is_a?(t) }
+             raise ExportError, "Unsupported model type: #{model.class}"
+           end
+
+           unless model.model
+             raise ExportError, "Model not loaded or trained"
+           end
+         end
+
+         def prepare_export_config(fine_model, dynamic_axes)
+           case fine_model
+           when Fine::TextClassifier, Fine::TextEmbedder
+             prepare_text_config(fine_model, dynamic_axes)
+           when Fine::ImageClassifier
+             prepare_image_config(fine_model, dynamic_axes)
+           when Fine::LLM
+             prepare_llm_config(fine_model, dynamic_axes)
+           end
+         end
+
+         def prepare_text_config(fine_model, dynamic_axes)
+           batch_size = 1
+           seq_length = fine_model.config.max_length
+
+           example_inputs = [
+             Torch.zeros([batch_size, seq_length], dtype: :int64), # input_ids
+             Torch.ones([batch_size, seq_length], dtype: :int64) # attention_mask
+           ]
+
+           input_names = %w[input_ids attention_mask]
+
+           output_names = if fine_model.is_a?(Fine::TextEmbedder)
+             %w[embeddings]
+           else
+             %w[logits]
+           end
+
+           dynamic_axes_config = if dynamic_axes
+             {
+               "input_ids" => { 0 => "batch_size", 1 => "sequence" },
+               "attention_mask" => { 0 => "batch_size", 1 => "sequence" },
+               output_names.first => { 0 => "batch_size" }
+             }
+           end
+
+           [example_inputs, input_names, output_names, dynamic_axes_config]
+         end
+
+         def prepare_image_config(fine_model, dynamic_axes)
+           # Get image size from config
+           image_size = fine_model.config.image_size || 224
+           batch_size = 1
+
+           example_inputs = [
+             Torch.zeros([batch_size, 3, image_size, image_size], dtype: :float32)
+           ]
+
+           input_names = %w[pixel_values]
+           output_names = %w[logits]
+
+           dynamic_axes_config = if dynamic_axes
+             {
+               "pixel_values" => { 0 => "batch_size" },
+               "logits" => { 0 => "batch_size" }
+             }
+           end
+
+           [example_inputs, input_names, output_names, dynamic_axes_config]
+         end
+
+         def prepare_llm_config(fine_model, dynamic_axes)
+           batch_size = 1
+           seq_length = 128 # Smaller default for export
+
+           example_inputs = [
+             Torch.zeros([batch_size, seq_length], dtype: :int64) # input_ids
+           ]
+
+           input_names = %w[input_ids]
+           output_names = %w[logits]
+
+           dynamic_axes_config = if dynamic_axes
+             {
+               "input_ids" => { 0 => "batch_size", 1 => "sequence" },
+               "logits" => { 0 => "batch_size", 1 => "sequence" }
+             }
+           end
+
+           [example_inputs, input_names, output_names, dynamic_axes_config]
+         end
+
+         def prepare_encoder_config(fine_model)
+           case fine_model
+           when Fine::TextEmbedder
+             batch_size = 1
+             seq_length = fine_model.config.max_length
+
+             example_inputs = [
+               Torch.zeros([batch_size, seq_length], dtype: :int64),
+               Torch.ones([batch_size, seq_length], dtype: :int64)
+             ]
+
+             input_names = %w[input_ids attention_mask]
+             output_names = %w[last_hidden_state]
+
+             dynamic_axes_config = {
+               "input_ids" => { 0 => "batch_size", 1 => "sequence" },
+               "attention_mask" => { 0 => "batch_size", 1 => "sequence" },
+               "last_hidden_state" => { 0 => "batch_size", 1 => "sequence" }
+             }
+
+             [example_inputs, input_names, output_names, dynamic_axes_config]
+           when Fine::ImageClassifier
+             image_size = fine_model.config.image_size || 224
+
+             example_inputs = [
+               Torch.zeros([1, 3, image_size, image_size], dtype: :float32)
+             ]
+
+             [example_inputs, %w[pixel_values], %w[features], { "pixel_values" => { 0 => "batch_size" } }]
+           end
+         end
+
+         def quantize_model(model_path, quantize_type)
+           # Dynamic quantization relies on the onnxruntime gem; if it is
+           # not installed this step is skipped with a warning
+           require "onnxruntime"
+
+           return unless %i[int8 uint8].include?(quantize_type)
+
+           quantized_path = model_path.sub(".onnx", "_quantized.onnx")
+
+           # Dynamic quantization of weights to the requested type
+           OnnxRuntime::Quantization.quantize_dynamic(
+             model_path,
+             quantized_path,
+             weight_type: quantize_type
+           )
+
+           # Replace original with quantized
+           FileUtils.mv(quantized_path, model_path)
+         rescue LoadError
+           warn "onnxruntime gem not installed, skipping quantization"
+         end
+       end
+     end
+   end
+ end
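
Once exported, the ONNX file can be served without torch.rb at all. A rough usage sketch with the onnxruntime gem; the path and token IDs below are hypothetical, and real inputs would come from the model's tokenizer:

    require "onnxruntime"

    model = OnnxRuntime::Model.new("model.onnx")
    input_ids = [[101, 2023, 2003, 1037, 3231, 102]] # dummy token IDs, batch of 1
    attention_mask = [Array.new(input_ids.first.length, 1)]

    # Input and output names match those chosen in prepare_text_config
    outputs = model.predict(
      "input_ids" => input_ids,
      "attention_mask" => attention_mask
    )
    logits = outputs["logits"].first
    puts "predicted class: #{logits.each_with_index.max.last}"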