grnexus 1.0.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
data/lib/grnexus.rb ADDED
@@ -0,0 +1,743 @@
+ require 'grnexus_core'
+ require 'grnexus_layers'
+ require 'grnexus_activations'
+ require 'grnexus_normalization'
+ require 'grnexus_numeric_proccessing'
+ require 'grnexus_text_proccessing'
+ require 'grnexus_callbacks'
+ require 'grnexus_machine_learning'
+ require 'json'
+ require 'time'
+ require 'zlib'
+ require 'fileutils' # FileUtils.mkdir_p is used in #save
+
+ module GRNexus
+   class NeuralNetwork
+     attr_accessor :layers, :loss_function, :optimizer, :learning_rate, :history, :name
+
+     def initialize(loss: 'mse', optimizer: 'sgd', learning_rate: 0.01, name: 'model')
+       @layers = []
+       @loss_function = loss
+       @optimizer = optimizer
+       @learning_rate = learning_rate
+       @name = name
+       @history = { loss: [], accuracy: [], val_loss: [], val_accuracy: [], lr: [] }
+       @optimizer_state = {}
+       @callbacks = []
+       @stop_training = false
+     end
+
+     def add(layer)
+       @layers << layer
+     end
+
+     def compile(loss: nil, optimizer: nil, learning_rate: nil, metrics: [])
+       @loss_function = loss if loss
+       @optimizer = optimizer if optimizer
+       @learning_rate = learning_rate if learning_rate
+       @metrics = metrics
+     end
+
+     def forward(input, training: true)
+       output = input
+       @layers.each do |layer|
+         if layer.is_a?(GRNEXUSLayer::DropoutLayer) || layer.is_a?(GRNEXUSLayer::BatchNormLayer)
+           output = layer.forward(output, training: training)
+         else
+           output = layer.forward(output)
+         end
+       end
+       output
+     end
+
+     def backward(gradient)
+       grad = gradient
+       @layers.reverse.each do |layer|
+         if layer.trainable?
+           grad = layer.backward(grad, @learning_rate)
+         else
+           grad = layer.backward(grad)
+         end
+       end
+       grad
+     end
+
+     def train(x_train, y_train, epochs: 10, batch_size: 32, verbose: true,
+               validation_data: nil, callbacks: [], shuffle: true)
+       @callbacks = callbacks
+       @stop_training = false
+
+       num_samples = x_train.length
+       num_batches = (num_samples.to_f / batch_size).ceil
+
+       # Validation data
+       x_val, y_val = validation_data if validation_data
+
+       # Callbacks - on_train_begin
+       callback_logs = { epochs: epochs, steps: num_batches, model: self }
+       @callbacks.each { |cb| cb.on_train_begin(callback_logs) }
+
+       epochs.times do |epoch|
+         break if @stop_training
+
+         # Callbacks - on_epoch_begin
+         epoch_logs = { epoch: epoch, model: self }
+         @callbacks.each { |cb| cb.on_epoch_begin(epoch, epoch_logs) }
+
+         epoch_loss = 0.0
+         correct = 0
+         total = 0
+
+         # Shuffle data
+         indices = shuffle ? (0...num_samples).to_a.shuffle : (0...num_samples).to_a
+
+         num_batches.times do |batch_idx|
+           batch_start = batch_idx * batch_size
+           batch_end = [batch_start + batch_size, num_samples].min
+           batch_indices = indices[batch_start...batch_end]
+
+           x_batch = batch_indices.map { |i| x_train[i] }
+           y_batch = batch_indices.map { |i| y_train[i] }
+
+           # Callbacks - on_batch_begin
+           batch_logs = { batch: batch_idx, size: batch_indices.length }
+           @callbacks.each { |cb| cb.on_batch_begin(batch_idx, batch_logs) }
+
+           # Forward pass
+           predictions = forward(x_batch, training: true)
+
+           # Calculate loss
+           loss = calculate_loss(predictions, y_batch)
+           epoch_loss += loss
+
+           # Calculate accuracy (for classification)
+           if @loss_function == 'cross_entropy'
+             correct += calculate_accuracy(predictions, y_batch)
+             total += batch_indices.length
+           end
+
+           # Backward pass
+           gradient = calculate_gradient(predictions, y_batch)
+           backward(gradient)
+
+           # Callbacks - on_batch_end
+           batch_logs[:loss] = loss
+           @callbacks.each { |cb| cb.on_batch_end(batch_idx, batch_logs) }
+         end
+
+         avg_loss = epoch_loss / num_batches
+         @history[:loss] << avg_loss
+         @history[:lr] << @learning_rate
+
+         # Calculate validation metrics if validation data provided
+         if validation_data
+           val_results = evaluate(x_val, y_val)
+           @history[:val_loss] << val_results[:loss]
+           @history[:val_accuracy] << val_results[:accuracy] if val_results[:accuracy]
+         end
+
+         # Prepare epoch end logs
+         epoch_end_logs = {
+           epoch: epoch,
+           loss: avg_loss,
+           lr: @learning_rate,
+           model: self
+         }
+
+         if @loss_function == 'cross_entropy'
+           accuracy = (correct.to_f / total * 100).round(2)
+           @history[:accuracy] << accuracy
+           epoch_end_logs[:accuracy] = accuracy
+         end
+
+         if validation_data
+           epoch_end_logs[:val_loss] = @history[:val_loss].last
+           epoch_end_logs[:val_accuracy] = @history[:val_accuracy].last if @history[:val_accuracy].any?
+         end
+
+         # Callbacks - on_epoch_end
+         @callbacks.each { |cb| cb.on_epoch_end(epoch, epoch_end_logs) }
+
+         # Check for early stopping
+         @callbacks.each do |cb|
+           if cb.respond_to?(:stop_training?) && cb.stop_training?
+             @stop_training = true
+             break
+           end
+         end
+
+         # Verbose output
+         if verbose
+           output_str = "Epoch #{epoch + 1}/#{epochs} - Loss: #{avg_loss.round(6)}"
+           output_str += " - Accuracy: #{accuracy}%" if @loss_function == 'cross_entropy'
+           output_str += " - Val Loss: #{@history[:val_loss].last.round(6)}" if validation_data
+           output_str += " - Val Accuracy: #{@history[:val_accuracy].last}%" if validation_data && @history[:val_accuracy].any?
+           output_str += " - LR: #{@learning_rate}" if @callbacks.any? { |cb| cb.is_a?(GRNEXUSCallbacks::LearningRateScheduler) || cb.is_a?(GRNEXUSCallbacks::ReduceLROnPlateau) }
+           puts output_str
+         end
+       end
+
+       # Callbacks - on_train_end
+       @callbacks.each { |cb| cb.on_train_end({ model: self }) }
+
+       @history
+     end
+
+     def predict(input)
+       forward(input, training: false)
+     end
+
+     def evaluate(x_test, y_test)
+       predictions = predict(x_test)
+       loss = calculate_loss(predictions, y_test)
+
+       result = { loss: loss }
+
+       if @loss_function == 'cross_entropy'
+         correct = calculate_accuracy(predictions, y_test)
+         accuracy = (correct.to_f / y_test.length * 100).round(2)
+         result[:accuracy] = accuracy
+       end
+
+       result
+     end
+
+     def save(filepath)
+       # Expand the path to handle both relative and absolute paths
+       expanded_path = File.expand_path(filepath)
+
+       # Create the directory if it does not exist
+       dir = File.dirname(expanded_path)
+       FileUtils.mkdir_p(dir) unless File.directory?(dir)
+
+       model_data = {
+         version: '2.0',
+         framework: 'GRNexus',
+         language: 'Ruby',
+         name: @name,
+         created_at: Time.now.iso8601,
+         architecture: serialize_architecture,
+         loss_function: @loss_function,
+         optimizer: @optimizer,
+         learning_rate: @learning_rate,
+         history: @history,
+         metadata: {
+           total_params: count_params[:total],
+           trainable_params: count_params[:trainable],
+           layers_count: @layers.length
+         }
+       }
+
+       json_data = JSON.generate(model_data)
+       compressed_data = Zlib::Deflate.deflate(json_data)
+
+       File.open(expanded_path, 'wb') do |file|
+         file.write(compressed_data)
+       end
+
+       puts "Model saved to #{expanded_path}"
+     end
+
+     def self.load(filepath)
+       # Expand the path to handle both relative and absolute paths
+       expanded_path = File.expand_path(filepath)
+
+       unless File.exist?(expanded_path)
+         raise "Model file not found: #{expanded_path}"
+       end
+
+       compressed_data = File.read(expanded_path, mode: 'rb')
+       json_data = Zlib::Inflate.inflate(compressed_data)
+       model_data = JSON.parse(json_data)
+
+       # Check compatibility
+       version = model_data['version'] || '1.0'
+       framework = model_data['framework'] || 'GRNexus'
+       source_lang = model_data['language'] || 'Unknown'
+
+       puts "Loading model: #{framework} v#{version} (created in #{source_lang})"
+
+       model = NeuralNetwork.new(
+         loss: model_data['loss_function'],
+         optimizer: model_data['optimizer'],
+         learning_rate: model_data['learning_rate'],
+         name: model_data['name'] || 'model'
+       )
+
+       model.deserialize_architecture(model_data['architecture'])
+
+       # Convert history keys from strings to symbols if necessary
+       if model_data['history'].is_a?(Hash)
+         model.history = model_data['history'].transform_keys(&:to_sym)
+       else
+         model.history = model_data['history']
+       end
+
+       puts "Model loaded from #{expanded_path}"
+       if model_data['metadata']
+         puts " Total params: #{model_data['metadata']['total_params']}"
+         puts " Layers: #{model_data['metadata']['layers_count']}"
+       end
+       model
+     end
+
+     def summary(line_length: 80)
+       puts "\n" + "=" * line_length
+       puts "Model: #{@name}"
+       puts "=" * line_length
+       puts "Layer (type)".ljust(30) + "Output Shape".ljust(25) + "Param #"
+       puts "-" * line_length
+
+       total_params = 0
+       trainable_params = 0
+
+       @layers.each_with_index do |layer, idx|
+         layer_name = layer.class.name.split('::').last
+         layer_name = "#{layer_name} (#{layer.activation.class.name.split('::').last})" if layer.respond_to?(:activation) && layer.activation
+
+         # Calculate output shape
+         output_shape = get_layer_output_shape(layer, idx)
+
+         # Calculate parameters
+         params = count_layer_params(layer)
+         total_params += params
+         trainable_params += params if layer.trainable?
+
+         # Format output
+         layer_info = "#{layer_name} (#{idx + 1})"
+         layer_info = layer_info.ljust(30)
+         layer_info += output_shape.ljust(25)
+         layer_info += params.to_s
+
+         puts layer_info
+       end
+
+       puts "=" * line_length
+       puts "Total params: #{total_params}"
+       puts "Trainable params: #{trainable_params}"
+       puts "Non-trainable params: #{total_params - trainable_params}"
+       puts "=" * line_length + "\n"
+     end
+
+     def to_json_architecture
+       {
+         name: @name,
+         layers: @layers.map.with_index do |layer, idx|
+           {
+             index: idx,
+             type: layer.class.name,
+             config: layer_config(layer),
+             trainable: layer.trainable?
+           }
+         end
+       }.to_json
+     end
+
+ def self.inspect_model(filepath)
336
+ """Inspecciona un modelo guardado sin cargarlo completamente"""
337
+ # Expandir ruta para manejar rutas relativas y absolutas
338
+ expanded_path = File.expand_path(filepath)
339
+
340
+ unless File.exist?(expanded_path)
341
+ puts "Error: Model file not found: #{expanded_path}"
342
+ return
343
+ end
344
+
345
+ compressed_data = File.read(expanded_path, mode: 'rb')
346
+ json_data = Zlib::Inflate.inflate(compressed_data)
347
+ model_data = JSON.parse(json_data)
348
+
349
+ puts "\n" + "=" * 80
350
+ puts "MODEL INSPECTION: #{filepath}"
351
+ puts "=" * 80
352
+ puts "Framework: #{model_data['framework'] || 'GRNexus'}"
353
+ puts "Version: #{model_data['version'] || '1.0'}"
354
+ puts "Language: #{model_data['language'] || 'Unknown'}"
355
+ puts "Name: #{model_data['name']}"
356
+ puts "Created: #{model_data['created_at']}"
357
+ puts "Loss Function: #{model_data['loss_function']}"
358
+ puts "Optimizer: #{model_data['optimizer']}"
359
+ puts "Learning Rate: #{model_data['learning_rate']}"
360
+
361
+ if model_data['metadata']
362
+ puts "\nMetadata:"
363
+ puts " Total Parameters: #{model_data['metadata']['total_params']}"
364
+ puts " Trainable Parameters: #{model_data['metadata']['trainable_params']}"
365
+ puts " Layers Count: #{model_data['metadata']['layers_count']}"
366
+ end
367
+
368
+ puts "\nArchitecture:"
369
+ puts "-" * 80
370
+ model_data['architecture'].each_with_index do |layer_data, idx|
371
+ layer_type = layer_data['type'].split('::').last
372
+ puts " Layer #{idx + 1}: #{layer_type}"
373
+ if layer_data['units']
374
+ puts " Units: #{layer_data['units']}"
375
+ end
376
+ if layer_data['activation']
377
+ puts " Activation: #{layer_data['activation'].split('::').last}"
378
+ end
379
+ puts " Trainable: #{layer_data['trainable']}"
380
+ end
381
+
382
+ if model_data['history'] && model_data['history']['loss']
383
+ puts "\nTraining History:"
384
+ puts " Epochs trained: #{model_data['history']['loss'].length}"
385
+ puts " Final loss: #{model_data['history']['loss'].last.round(6)}"
386
+ if model_data['history']['accuracy'] && model_data['history']['accuracy'].any?
387
+ puts " Final accuracy: #{model_data['history']['accuracy'].last.round(2)}%"
388
+ end
389
+ end
390
+
391
+ puts "=" * 80 + "\n"
392
+ end
393
+
394
+     def plot_history(metrics: ['loss', 'accuracy'])
+       # Text-based plotting
+       puts "\n" + "=" * 80
+       puts "Training History"
+       puts "=" * 80
+
+       metrics.each do |metric|
+         next unless @history[metric.to_sym]&.any?
+
+         puts "\n#{metric.capitalize}:"
+         values = @history[metric.to_sym]
+         min_val = values.min
+         max_val = values.max
+         range = max_val - min_val
+
+         values.each_with_index do |val, epoch|
+           normalized = range > 0 ? ((val - min_val) / range * 50).to_i : 25
+           bar = '█' * normalized
+           puts "Epoch #{epoch + 1}: #{bar} #{val.round(4)}"
+         end
+       end
+
+       puts "=" * 80 + "\n"
+     end
+
+     def count_params
+       total = 0
+       trainable = 0
+
+       @layers.each do |layer|
+         params = count_layer_params(layer)
+         total += params
+         trainable += params if layer.trainable?
+       end
+
+       { total: total, trainable: trainable, non_trainable: total - trainable }
+     end
+
+     def deserialize_architecture(architecture_data)
+       architecture_data.each do |layer_data|
+         # Handle Python-style names (lib.grnexus_layers.DenseLayer) and Ruby-style names (GRNEXUSLayer::DenseLayer)
+         layer_type = layer_data['type']
+
+         # Extract the class name without the module/namespace
+         # Python: "lib.grnexus_layers.DenseLayer" -> "DenseLayer"
+         # Ruby: "GRNEXUSLayer::DenseLayer" -> "DenseLayer"
+         if layer_type.include?('.')
+           # Python format
+           layer_class_name = layer_type.split('.').last
+         else
+           # Ruby format
+           layer_class_name = layer_type.split('::').last
+         end
+
+         begin
+           layer_class = GRNEXUSLayer.const_get(layer_class_name)
+         rescue NameError
+           puts "Warning: Could not find layer class #{layer_class_name}, skipping"
+           next
+         end
+
+         # Reconstruct layer based on type
+         layer = case layer_class_name
+         when 'DenseLayer'
+           # Handle activations saved from Python or Ruby
+           activation = nil
+           if layer_data['activation']
+             activation_str = layer_data['activation']
+
+             # Extract the activation name
+             # Python: "lib.grnexus_activations.ReLU" -> "ReLU"
+             # Ruby: "GRNEXUSActivations::ReLU" -> "ReLU"
+             if activation_str.include?('.')
+               activation_name = activation_str.split('.').last
+             else
+               activation_name = activation_str.split('::').last
+             end
+
+             begin
+               # Try GRNEXUSActivations first
+               activation = GRNEXUSActivations.const_get(activation_name).new
+             rescue NameError
+               # Fall back to GRNEXUSNormalization
+               begin
+                 activation = GRNEXUSNormalization.const_get(activation_name).new
+               rescue NameError
+                 puts "Warning: Could not find activation #{activation_name}"
+                 activation = nil
+               end
+             end
+           end
+
+           l = layer_class.new(
+             units: layer_data['units'],
+             input_dim: layer_data['input_dim'],
+             activation: activation
+           )
+           l.weights = layer_data['weights'] if layer_data['weights']
+           l.biases = layer_data['biases'] if layer_data['biases']
+           l
+         when 'ActivationLayer'
+           activation_str = layer_data['activation']
+
+           # Extract the activation name
+           if activation_str.include?('.')
+             activation_name = activation_str.split('.').last
+           else
+             activation_name = activation_str.split('::').last
+           end
+
+           activation = GRNEXUSActivations.const_get(activation_name).new rescue nil
+           layer_class.new(activation)
+         when 'DropoutLayer'
+           layer_class.new(rate: layer_data['rate'] || 0.5)
+         when 'BatchNormLayer'
+           l = layer_class.new(
+             epsilon: layer_data['epsilon'] || 1e-5,
+             momentum: layer_data['momentum'] || 0.1
+           )
+           l.gamma = layer_data['gamma'] || 1.0
+           l.beta = layer_data['beta'] || 0.0
+           # CRITICAL: Restore running statistics for BatchNorm inference
+           if layer_data['running_mean'] && layer_data['running_var']
+             l.instance_variable_set(:@running_mean, layer_data['running_mean'])
+             l.instance_variable_set(:@running_var, layer_data['running_var'])
+           else
+             # Initialize running stats if they are not present in the saved file
+             l.instance_variable_set(:@running_mean, nil)
+             l.instance_variable_set(:@running_var, nil)
+           end
+           l
+         else
+           # For other layers, try to instantiate with default parameters
+           layer_class.new
+         end
+
+         @layers << layer
+       end
+     end
+
+     private
+
+     def calculate_loss(predictions, targets)
+       case @loss_function
+       when 'mse'
+         mse_loss(predictions, targets)
+       when 'cross_entropy'
+         cross_entropy_loss(predictions, targets)
+       else
+         raise "Unknown loss function: #{@loss_function}"
+       end
+     end
+
+     def calculate_gradient(predictions, targets)
+       case @loss_function
+       when 'mse'
+         mse_gradient(predictions, targets)
+       when 'cross_entropy'
+         cross_entropy_gradient(predictions, targets)
+       else
+         raise "Unknown loss function: #{@loss_function}"
+       end
+     end
+
+
558
+ def mse_loss(predictions, targets)
559
+ total_loss = 0.0
560
+ predictions.each_with_index do |pred, i|
561
+ target = targets[i]
562
+ pred.each_with_index do |p, j|
563
+ diff = p - target[j]
564
+ total_loss += diff * diff
565
+ end
566
+ end
567
+ total_loss / (predictions.length * predictions[0].length)
568
+ end
569
+
570
+ def mse_gradient(predictions, targets)
571
+ gradients = []
572
+ predictions.each_with_index do |pred, i|
573
+ target = targets[i]
574
+ grad = pred.each_with_index.map do |p, j|
575
+ 2.0 * (p - target[j]) / (predictions.length * predictions[0].length)
576
+ end
577
+ gradients << grad
578
+ end
579
+ gradients
580
+ end
581
+
582
+     def cross_entropy_loss(predictions, targets)
+       total_loss = 0.0
+       epsilon = 1e-15
+
+       predictions.each_with_index do |pred, i|
+         target = targets[i]
+         pred.each_with_index do |p, j|
+           # Handle NaN and Infinity
+           p = 0.5 if p.nan? || p.infinite?
+
+           # Clip to prevent log(0)
+           p_clipped = [[p, epsilon].max, 1.0 - epsilon].min
+
+           # Only calculate if target is non-zero
+           if target[j] > 0
+             log_val = Math.log(p_clipped)
+             total_loss -= target[j] * log_val unless log_val.nan?
+           end
+         end
+       end
+
+       total_loss / predictions.length
+     end
+
+     def cross_entropy_gradient(predictions, targets)
+       gradients = []
+       predictions.each_with_index do |pred, i|
+         target = targets[i]
+         grad = pred.each_with_index.map { |p, j| p - target[j] }
+         gradients << grad
+       end
+       gradients
+     end
+
+     def calculate_accuracy(predictions, targets)
+       correct = 0
+       predictions.each_with_index do |pred, i|
+         # Handle NaN values in predictions
+         valid_pred = pred.map { |v| v.nan? || v.infinite? ? 0.0 : v }
+
+         # Find max indices safely
+         pred_class = valid_pred.each_with_index.max_by { |val, _| val }[1]
+         true_class = targets[i].each_with_index.max_by { |val, _| val }[1]
+
+         correct += 1 if pred_class == true_class
+       end
+       correct
+     end
+
+     def clean_nan(value)
+       case value
+       when Array
+         value.map { |v| clean_nan(v) }
+       when Float
+         value.nan? || value.infinite? ? 0.0 : value
+       else
+         value
+       end
+     end
+
+     def serialize_architecture
+       @layers.map do |layer|
+         layer_data = {
+           type: layer.class.name,
+           trainable: layer.trainable?
+         }
+
+         # Serialize layer-specific parameters
+         if layer.respond_to?(:weights) && layer.weights
+           layer_data[:weights] = clean_nan(layer.weights)
+         end
+
+         if layer.respond_to?(:biases) && layer.biases
+           layer_data[:biases] = clean_nan(layer.biases)
+         end
+
+         if layer.respond_to?(:units)
+           layer_data[:units] = layer.units
+           layer_data[:input_dim] = layer.input_dim if layer.respond_to?(:input_dim)
+         end
+
+         if layer.respond_to?(:activation) && layer.activation
+           layer_data[:activation] = layer.activation.class.name
+         end
+
+         if layer.respond_to?(:rate)
+           layer_data[:rate] = layer.rate
+         end
+
+         if layer.respond_to?(:epsilon)
+           layer_data[:epsilon] = layer.epsilon
+           layer_data[:momentum] = layer.momentum
+           layer_data[:gamma] = layer.gamma
+           layer_data[:beta] = layer.beta
+           # CRITICAL: Save running statistics for BatchNorm inference
+           if layer.instance_variable_defined?(:@running_mean)
+             layer_data[:running_mean] = clean_nan(layer.instance_variable_get(:@running_mean))
+             layer_data[:running_var] = clean_nan(layer.instance_variable_get(:@running_var))
+           end
+         end
+
+         layer_data
+       end
+     end
+
+     def count_layer_params(layer)
+       params = 0
+
+       if layer.respond_to?(:weights) && layer.weights
+         params += layer.weights.flatten.length
+       end
+
+       if layer.respond_to?(:biases) && layer.biases
+         params += layer.biases.length
+       end
+
+       if layer.respond_to?(:gamma)
+         params += 2 # gamma and beta
+       end
+
+       params
+     end
+
+     def get_layer_output_shape(layer, idx)
+       if layer.respond_to?(:units)
+         "(None, #{layer.units})"
+       elsif layer.respond_to?(:filters)
+         "(None, ?, ?, #{layer.filters})"
+       elsif layer.is_a?(GRNEXUSLayer::DropoutLayer) || layer.is_a?(GRNEXUSLayer::BatchNormLayer)
+         # Dropout and BatchNorm keep the same shape as the last layer that defines units,
+         # so walk backwards to find it
+         (idx - 1).downto(0) do |i|
+           if @layers[i].respond_to?(:units)
+             return "(None, #{@layers[i].units})"
+           end
+         end
+         "(None, ?)"
+       else
+         "(None, ?)"
+       end
+     end
+
+     def layer_config(layer)
+       config = {}
+
+       if layer.respond_to?(:units)
+         config[:units] = layer.units
+         config[:input_dim] = layer.input_dim if layer.respond_to?(:input_dim)
+       end
+
+       if layer.respond_to?(:activation) && layer.activation
+         config[:activation] = layer.activation.class.name
+       end
+
+       if layer.respond_to?(:rate)
+         config[:rate] = layer.rate
+       end
+
+       config
+     end
+   end
+ end
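
For orientation, below is a minimal usage sketch of the API added in this file; it is not part of the gem. It assumes that GRNEXUSLayer::DenseLayer accepts the units:, input_dim: and activation: keywords used by deserialize_architecture above, that GRNEXUSActivations::ReLU exists as referenced in the code comments, and that a nil activation is accepted for the output layer; the file path, model name, and hyperparameters are arbitrary. The callback class only implements the hook protocol that #train drives (on_train_begin, on_epoch_begin, on_batch_begin, on_batch_end, on_epoch_end, on_train_end, and the optional stop_training? query).

require 'grnexus'

# Hypothetical callback showing the hook protocol used by NeuralNetwork#train.
class LossLogger
  def on_train_begin(logs); end
  def on_epoch_begin(epoch, logs); end
  def on_batch_begin(batch, logs); end
  def on_batch_end(batch, logs); end

  def on_epoch_end(epoch, logs)
    # logs[:loss] is the average epoch loss prepared in #train
    puts "epoch #{epoch + 1}: loss=#{logs[:loss].round(4)}"
  end

  def on_train_end(logs); end

  def stop_training?
    false
  end
end

# Hypothetical two-layer regression model; layer/activation constructor
# signatures are assumptions taken from deserialize_architecture above.
model = GRNexus::NeuralNetwork.new(loss: 'mse', optimizer: 'sgd',
                                   learning_rate: 0.01, name: 'demo')
model.add(GRNEXUSLayer::DenseLayer.new(units: 8, input_dim: 2,
                                       activation: GRNEXUSActivations::ReLU.new))
model.add(GRNEXUSLayer::DenseLayer.new(units: 1, input_dim: 8, activation: nil))

# Samples and targets are plain nested arrays, as #train expects.
x_train = [[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0]]
y_train = [[0.0], [1.0], [1.0], [0.0]]

model.summary
history = model.train(x_train, y_train, epochs: 100, batch_size: 4,
                      verbose: false, callbacks: [LossLogger.new])
puts "final loss: #{history[:loss].last}"
puts model.predict([[1.0, 0.0]]).inspect

# Persistence round-trip through the compressed-JSON format implemented above
# (the path is arbitrary).
model.save('demo_model.grnexus')
GRNexus::NeuralNetwork.inspect_model('demo_model.grnexus')
restored = GRNexus::NeuralNetwork.load('demo_model.grnexus')
puts restored.evaluate(x_train, y_train).inspect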