grnexus 1.0.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,462 @@
+ require 'ffi'
+ require 'rbconfig'
+
+ module GRNEXUSActivations
+   extend FFI::Library
+
+   # Detect the host operating system and return the path to the native library
+   def self.detect_library
+     script_dir = File.dirname(File.expand_path(__FILE__))
+     os = RbConfig::CONFIG['host_os']
+
+     case os
+     when /mswin|mingw|cygwin/
+       File.join(script_dir, '..', 'exports', 'Windows', 'activations.dll')
+     when /darwin/
+       File.join(script_dir, '..', 'exports', 'Mac', 'activations.dylib')
+     when /linux/
+       File.join(script_dir, '..', 'exports', 'Linux', 'activations.so')
+     else
+       raise "Unsupported operating system: #{os}"
+     end
+   end
+
+   # Load the native library
+   ffi_lib detect_library
+
+   # Define the GRNexusData struct (mirrors the native C struct layout)
+   class GRNexusData < FFI::Struct
+     layout :data, :pointer,
+            :type, :int,
+            :size, :size_t,
+            :stride, :size_t,
+            :dims, [:size_t, 3]
+   end
+
+   # Available native functions: name => [symbol, argument types, return type]
+   FUNCTIONS = {
+     'Linear' => [:Linear, [:pointer, :pointer, :bool], :int],
+     'Step' => [:Step, [:pointer, :pointer, :bool], :int],
+     'Sigmoid' => [:Sigmoid, [:pointer, :pointer, :bool], :int],
+     'Tanh' => [:Tanh, [:pointer, :pointer, :bool], :int],
+     'ReLU' => [:ReLU, [:pointer, :pointer, :bool], :int],
+     'LeakyReLU' => [:LeakyReLU, [:pointer, :pointer, :bool, :double], :int],
+     'PReLU' => [:PReLU, [:pointer, :pointer, :bool, :pointer], :int],
+     'ELU' => [:ELU, [:pointer, :pointer, :bool, :double], :int],
+     'SELU' => [:SELU, [:pointer, :pointer, :bool], :int],
+     'Softplus' => [:Softplus, [:pointer, :pointer, :bool], :int],
+     'Softsign' => [:Softsign, [:pointer, :pointer, :bool], :int],
+     'HardSigmoid' => [:HardSigmoid, [:pointer, :pointer, :bool], :int],
+     'HardTanh' => [:HardTanh, [:pointer, :pointer, :bool], :int],
+     'ThresholdedReLU' => [:ThresholdedReLU, [:pointer, :pointer, :bool, :double], :int],
+     'GELU' => [:GELU, [:pointer, :pointer, :bool], :int],
+     'Swish' => [:Swish, [:pointer, :pointer, :bool], :int],
+     'Mish' => [:Mish, [:pointer, :pointer, :bool], :int],
+     'LiSHT' => [:LiSHT, [:pointer, :pointer, :bool], :int],
+     'ReLUSquared' => [:ReLUSquared, [:pointer, :pointer, :bool], :int],
+     'SquaredReLU' => [:SquaredReLU, [:pointer, :pointer, :bool], :int],
+     'CELU' => [:CELU, [:pointer, :pointer, :bool, :double], :int],
+     'HardShrink' => [:HardShrink, [:pointer, :pointer, :bool, :double], :int],
+     'SoftShrink' => [:SoftShrink, [:pointer, :pointer, :bool, :double], :int],
+     'TanhShrink' => [:TanhShrink, [:pointer, :pointer, :bool], :int],
+     'ReLU6' => [:ReLU6, [:pointer, :pointer, :bool], :int],
+     'HardSwish' => [:HardSwish, [:pointer, :pointer, :bool], :int],
+     'SiLU' => [:SiLU, [:pointer, :pointer, :bool], :int],
+     'GLU' => [:GLU, [:pointer, :pointer, :bool, :size_t], :int],
+     'BReLU' => [:BReLU, [:pointer, :pointer, :bool], :int],
+     'ARelu' => [:ARelu, [:pointer, :pointer, :bool, :double, :double], :int],
+     'FReLU' => [:FReLU, [:pointer, :pointer, :bool, :double], :int],
+     'Snake' => [:Snake, [:pointer, :pointer, :bool, :double], :int],
+     'SnakeBeta' => [:SnakeBeta, [:pointer, :pointer, :bool, :double, :double], :int],
+     'ISRU' => [:ISRU, [:pointer, :pointer, :bool, :double], :int],
+     'ISRLU' => [:ISRLU, [:pointer, :pointer, :bool, :double], :int],
+     'Maxout' => [:Maxout, [:pointer, :pointer, :bool, :size_t], :int],
+     'Minout' => [:Minout, [:pointer, :pointer, :bool, :size_t], :int],
+   }
+
+   # Attach each native function to this module
+   FUNCTIONS.each do |_name, (func, args, ret)|
+     attach_function func, args, ret
+   end
+
+   # Build a GRNexusData struct from a Ruby scalar or array of numbers
+   def self.create_grnexus_data(array_or_scalar)
+     # Wrap a scalar in an array
+     values = array_or_scalar.is_a?(Array) ? array_or_scalar : [array_or_scalar]
+     size = values.length
+
+     # Copy the values into a native buffer of doubles
+     buffer = FFI::MemoryPointer.new(:double, size)
+     buffer.write_array_of_double(values)
+
+     data = FFI::MemoryPointer.new(GRNexusData.size)
+     struct_instance = GRNexusData.new(data)
+     struct_instance[:data] = buffer
+     struct_instance[:type] = size == 1 ? 0 : 1 # SCALAR=0, ARRAY=1
+     struct_instance[:size] = size
+     struct_instance[:stride] = 1
+     struct_instance[:dims][0] = size
+     struct_instance[:dims][1] = 0
+     struct_instance[:dims][2] = 0
+
+     [data, buffer]
+   end
+
+   # Build an empty GRNexusData struct that will receive the output
+   def self.create_output_grnexus_data(size)
+     # Allocate a writable output buffer
+     buffer = FFI::MemoryPointer.new(:double, size)
+
+     data = FFI::MemoryPointer.new(GRNexusData.size)
+     struct_instance = GRNexusData.new(data)
+     struct_instance[:data] = buffer
+     struct_instance[:type] = 1 # ARRAY
+     struct_instance[:size] = size
+     struct_instance[:stride] = 1
+     struct_instance[:dims][0] = size
+     struct_instance[:dims][1] = 0
+     struct_instance[:dims][2] = 0
+
+     [data, buffer]
+   end
+
+   # Read `size` doubles back out of a native buffer
+   def self.read_grnexus_data(original_ptr, size)
+     original_ptr.read_array_of_double(size)
+   end
+
+   class ActivationFunction
+     def call(x, derivative: false)
+       raise NotImplementedError, "You must implement the activation function"
+     end
+   end
+
+   class BaseActivation < ActivationFunction
+     protected
+
+     def execute_activation(func_name, input_values, derivative: false, **kwargs)
+       input_data, input_buffer = GRNEXUSActivations.create_grnexus_data(input_values)
+       output_size = input_values.is_a?(Array) ? input_values.length : 1
+       output_data, output_buffer = GRNEXUSActivations.create_output_grnexus_data(output_size)
+
+       args = [input_data, output_data, derivative]
+
+       # Append function-specific arguments
+       case func_name
+       when :LeakyReLU, :ELU, :ThresholdedReLU, :CELU, :HardShrink, :SoftShrink, :FReLU, :Snake, :ISRU, :ISRLU
+         args << (kwargs[:param1] || 0.01)
+       when :PReLU
+         alpha_ptr = FFI::MemoryPointer.new(:double)
+         alpha_ptr.write_double(kwargs[:param1] || 0.01)
+         args << alpha_ptr
+       when :ARelu
+         args << (kwargs[:param1] || 0.01) << (kwargs[:param2] || 1.0)
+       when :SnakeBeta
+         args << (kwargs[:param1] || 1.0) << (kwargs[:param2] || 1.0)
+       when :GLU
+         args << (kwargs[:dim] || 0)
+       when :Maxout, :Minout
+         args << (kwargs[:num_pieces] || 2)
+       end
+
+       # Call the function on the GRNEXUSActivations module, not on the instance
+       result = GRNEXUSActivations.send(func_name, *args)
+
+       if result == 0
+         output_buffer.read_array_of_double(output_size)
+       else
+         raise "Function #{func_name} failed with code: #{result}"
+       end
+     ensure
+       # FFI::MemoryPointer instances are released automatically by the GC
+     end
+   end
+
+   class Linear < BaseActivation
+     def call(x, derivative: false)
+       execute_activation(:Linear, x, derivative: derivative)
+     end
+   end
+
+   class Step < BaseActivation
+     def call(x, derivative: false)
+       execute_activation(:Step, x, derivative: derivative)
+     end
+   end
+
+   class Sigmoid < BaseActivation
+     def call(x, derivative: false)
+       execute_activation(:Sigmoid, x, derivative: derivative)
+     end
+   end
+
+   class Tanh < BaseActivation
+     def call(x, derivative: false)
+       execute_activation(:Tanh, x, derivative: derivative)
+     end
+   end
+
+   class ReLU < BaseActivation
+     def call(x, derivative: false)
+       execute_activation(:ReLU, x, derivative: derivative)
+     end
+   end
+
+   class LeakyReLU < BaseActivation
+     def initialize(alpha: 0.01)
+       @alpha = alpha
+     end
+
+     def call(x, derivative: false)
+       execute_activation(:LeakyReLU, x, derivative: derivative, param1: @alpha)
+     end
+   end
+
+   class PReLU < BaseActivation
+     def initialize(alpha: 0.01)
+       @alpha = alpha
+     end
+
+     def call(x, derivative: false)
+       execute_activation(:PReLU, x, derivative: derivative, param1: @alpha)
+     end
+   end
+
+   class ELU < BaseActivation
+     def initialize(alpha: 1.0)
+       @alpha = alpha
+     end
+
+     def call(x, derivative: false)
+       execute_activation(:ELU, x, derivative: derivative, param1: @alpha)
+     end
+   end
+
+   class SELU < BaseActivation
+     def call(x, derivative: false)
+       execute_activation(:SELU, x, derivative: derivative)
+     end
+   end
+
+   class Softplus < BaseActivation
+     def call(x, derivative: false)
+       execute_activation(:Softplus, x, derivative: derivative)
+     end
+   end
+
+   class Softsign < BaseActivation
+     def call(x, derivative: false)
+       execute_activation(:Softsign, x, derivative: derivative)
+     end
+   end
+
+   class HardSigmoid < BaseActivation
+     def call(x, derivative: false)
+       execute_activation(:HardSigmoid, x, derivative: derivative)
+     end
+   end
+
+   class HardTanh < BaseActivation
+     def call(x, derivative: false)
+       execute_activation(:HardTanh, x, derivative: derivative)
+     end
+   end
+
+   class ThresholdedReLU < BaseActivation
+     def initialize(theta: 1.0)
+       @theta = theta
+     end
+
+     def call(x, derivative: false)
+       execute_activation(:ThresholdedReLU, x, derivative: derivative, param1: @theta)
+     end
+   end
+
+   class GELU < BaseActivation
+     def call(x, derivative: false)
+       execute_activation(:GELU, x, derivative: derivative)
+     end
+   end
+
+   class Swish < BaseActivation
+     def call(x, derivative: false)
+       execute_activation(:Swish, x, derivative: derivative)
+     end
+   end
+
+   class Mish < BaseActivation
+     def call(x, derivative: false)
+       execute_activation(:Mish, x, derivative: derivative)
+     end
+   end
+
+   class LiSHT < BaseActivation
+     def call(x, derivative: false)
+       execute_activation(:LiSHT, x, derivative: derivative)
+     end
+   end
+
+   class ReLUSquared < BaseActivation
+     def call(x, derivative: false)
+       execute_activation(:ReLUSquared, x, derivative: derivative)
+     end
+   end
+
+   class SquaredReLU < BaseActivation
+     def call(x, derivative: false)
+       execute_activation(:SquaredReLU, x, derivative: derivative)
+     end
+   end
+
+   class CELU < BaseActivation
+     def initialize(alpha: 1.0)
+       @alpha = alpha
+     end
+
+     def call(x, derivative: false)
+       execute_activation(:CELU, x, derivative: derivative, param1: @alpha)
+     end
+   end
+
+   class HardShrink < BaseActivation
+     def initialize(lambda: 0.5)
+       @lambda = lambda
+     end
+
+     def call(x, derivative: false)
+       execute_activation(:HardShrink, x, derivative: derivative, param1: @lambda)
+     end
+   end
+
+   class SoftShrink < BaseActivation
+     def initialize(lambda: 0.5)
+       @lambda = lambda
+     end
+
+     def call(x, derivative: false)
+       execute_activation(:SoftShrink, x, derivative: derivative, param1: @lambda)
+     end
+   end
+
+   class TanhShrink < BaseActivation
+     def call(x, derivative: false)
+       execute_activation(:TanhShrink, x, derivative: derivative)
+     end
+   end
+
+   class ReLU6 < BaseActivation
+     def call(x, derivative: false)
+       execute_activation(:ReLU6, x, derivative: derivative)
+     end
+   end
+
+   class HardSwish < BaseActivation
+     def call(x, derivative: false)
+       execute_activation(:HardSwish, x, derivative: derivative)
+     end
+   end
+
+   class SiLU < BaseActivation
+     def call(x, derivative: false)
+       execute_activation(:SiLU, x, derivative: derivative)
+     end
+   end
+
+   class GLU < BaseActivation
+     def initialize(dim: 1)
+       @dim = dim
+     end
+
+     def call(x, derivative: false)
+       execute_activation(:GLU, x, derivative: derivative, dim: @dim)
+     end
+   end
+
+   class BReLU < BaseActivation
+     def call(x, derivative: false)
+       execute_activation(:BReLU, x, derivative: derivative)
+     end
+   end
+
+   class ARelu < BaseActivation
+     def initialize(alpha: 0.0, beta: 1.0)
+       @alpha = alpha
+       @beta = beta
+     end
+
+     def call(x, derivative: false)
+       execute_activation(:ARelu, x, derivative: derivative, param1: @alpha, param2: @beta)
+     end
+   end
+
+   class FReLU < BaseActivation
+     def initialize(alpha: 1.0)
+       @alpha = alpha
+     end
+
+     def call(x, derivative: false)
+       execute_activation(:FReLU, x, derivative: derivative, param1: @alpha)
+     end
+   end
+
+   class Snake < BaseActivation
+     def initialize(alpha: 1.0)
+       @alpha = alpha
+     end
+
+     def call(x, derivative: false)
+       execute_activation(:Snake, x, derivative: derivative, param1: @alpha)
+     end
+   end
+
+   class SnakeBeta < BaseActivation
+     def initialize(alpha: 1.0, beta: 1.0)
+       @alpha = alpha
+       @beta = beta
+     end
+
+     def call(x, derivative: false)
+       execute_activation(:SnakeBeta, x, derivative: derivative, param1: @alpha, param2: @beta)
+     end
+   end
+
+   class ISRU < BaseActivation
+     def initialize(alpha: 1.0)
+       @alpha = alpha
+     end
+
+     def call(x, derivative: false)
+       execute_activation(:ISRU, x, derivative: derivative, param1: @alpha)
+     end
+   end
+
+   class ISRLU < BaseActivation
+     def initialize(alpha: 1.0)
+       @alpha = alpha
+     end
+
+     def call(x, derivative: false)
+       execute_activation(:ISRLU, x, derivative: derivative, param1: @alpha)
+     end
+   end
+
+   class Maxout < BaseActivation
+     def initialize(num_pieces: 2)
+       @num_pieces = num_pieces
+     end
+
+     def call(x, derivative: false)
+       execute_activation(:Maxout, x, derivative: derivative, num_pieces: @num_pieces)
+     end
+   end
+
+   class Minout < BaseActivation
+     def initialize(num_pieces: 2)
+       @num_pieces = num_pieces
+     end
+
+     def call(x, derivative: false)
+       execute_activation(:Minout, x, derivative: derivative, num_pieces: @num_pieces)
+     end
+   end
+ end
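
As a quick orientation for the file above: every native activation is attached to GRNEXUSActivations and wrapped in a small class whose #call marshals Ruby numbers through GRNexusData structs. The following is a minimal usage sketch, not part of the package; it assumes the bundled native library under exports/ loads on your platform, and the require path is an assumption, since the gem's entry point is not shown in this diff.

# Hypothetical usage sketch; assumes the native activations library loaded successfully.
require 'grnexus' # assumed require path, not shown in this diff

relu = GRNEXUSActivations::ReLU.new
relu.call([-1.0, 0.5, 2.0])                    # forward pass, e.g. [0.0, 0.5, 2.0]
relu.call([-1.0, 0.5, 2.0], derivative: true)  # element-wise derivative

leaky = GRNEXUSActivations::LeakyReLU.new(alpha: 0.1)
leaky.call(-2.0)                               # scalars are wrapped into a 1-element array

Parameterized activations (LeakyReLU, ELU, Snake, Maxout, and so on) take their coefficients at construction time and forward them to the native call as param1/param2, dim, or num_pieces, matching the dispatch in BaseActivation#execute_activation.
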
@@ -0,0 +1,249 @@
+ # Callbacks system for training
+ module GRNEXUSCallbacks
+   class Callback
+     def on_train_begin(logs = {})
+     end
+
+     def on_train_end(logs = {})
+     end
+
+     def on_epoch_begin(epoch, logs = {})
+     end
+
+     def on_epoch_end(epoch, logs = {})
+     end
+
+     def on_batch_begin(batch, logs = {})
+     end
+
+     def on_batch_end(batch, logs = {})
+     end
+   end
+
+   class EarlyStopping < Callback
+     attr_reader :stopped_epoch
+
+     def initialize(monitor: 'loss', patience: 10, min_delta: 0.0001, mode: 'min', verbose: true)
+       @monitor = monitor
+       @patience = patience
+       @min_delta = min_delta
+       @mode = mode
+       @verbose = verbose
+       @wait = 0
+       @stopped_epoch = 0
+       @best = mode == 'min' ? Float::INFINITY : -Float::INFINITY
+       @stop_training = false
+     end
+
+     def on_train_begin(logs = {})
+       @wait = 0
+       @stopped_epoch = 0
+       @best = @mode == 'min' ? Float::INFINITY : -Float::INFINITY
+       @stop_training = false
+     end
+
+     def on_epoch_end(epoch, logs = {})
+       current = logs[@monitor.to_sym]
+       return unless current
+
+       if monitor_improved?(current)
+         @best = current
+         @wait = 0
+       else
+         @wait += 1
+         if @wait >= @patience
+           @stopped_epoch = epoch
+           @stop_training = true
+           puts "Early stopping triggered at epoch #{epoch + 1}" if @verbose
+         end
+       end
+     end
+
+     def stop_training?
+       @stop_training
+     end
+
+     private
+
+     def monitor_improved?(current)
+       if @mode == 'min'
+         current < @best - @min_delta
+       else
+         current > @best + @min_delta
+       end
+     end
+   end
+
+   class ModelCheckpoint < Callback
+     def initialize(filepath:, monitor: 'loss', mode: 'min', save_best_only: true, verbose: true)
+       @filepath = filepath
+       @monitor = monitor
+       @mode = mode
+       @save_best_only = save_best_only
+       @verbose = verbose
+       @best = mode == 'min' ? Float::INFINITY : -Float::INFINITY
+     end
+
+     def on_epoch_end(epoch, logs = {})
+       current = logs[@monitor.to_sym]
+       return unless current
+
+       should_save = !@save_best_only || monitor_improved?(current)
+
+       if should_save
+         @best = current if monitor_improved?(current)
+         filepath = @filepath.gsub('{epoch}', (epoch + 1).to_s)
+         logs[:model].save(filepath)
+         puts "Epoch #{epoch + 1}: #{@monitor} improved to #{current.round(6)}, saving model to #{filepath}" if @verbose
+       end
+     end
+
+     private
+
+     def monitor_improved?(current)
+       if @mode == 'min'
+         current < @best
+       else
+         current > @best
+       end
+     end
+   end
+
+   class LearningRateScheduler < Callback
+     def initialize(schedule:, verbose: true)
+       @schedule = schedule
+       @verbose = verbose
+     end
+
+     def on_epoch_begin(epoch, logs = {})
+       new_lr = if @schedule.is_a?(Proc)
+                  @schedule.call(epoch)
+                elsif @schedule.is_a?(Hash)
+                  @schedule[epoch] || logs[:model].learning_rate
+                else
+                  logs[:model].learning_rate
+                end
+
+       logs[:model].learning_rate = new_lr
+       puts "Epoch #{epoch + 1}: Learning rate set to #{new_lr}" if @verbose
+     end
+   end
+
+   class ReduceLROnPlateau < Callback
+     def initialize(monitor: 'loss', factor: 0.5, patience: 5, min_delta: 0.0001,
+                    min_lr: 1e-7, mode: 'min', verbose: true)
+       @monitor = monitor
+       @factor = factor
+       @patience = patience
+       @min_delta = min_delta
+       @min_lr = min_lr
+       @mode = mode
+       @verbose = verbose
+       @wait = 0
+       @best = mode == 'min' ? Float::INFINITY : -Float::INFINITY
+     end
+
+     def on_train_begin(logs = {})
+       @wait = 0
+       @best = @mode == 'min' ? Float::INFINITY : -Float::INFINITY
+     end
+
+     def on_epoch_end(epoch, logs = {})
+       current = logs[@monitor.to_sym]
+       return unless current
+
+       if monitor_improved?(current)
+         @best = current
+         @wait = 0
+       else
+         @wait += 1
+         if @wait >= @patience
+           old_lr = logs[:model].learning_rate
+           new_lr = [old_lr * @factor, @min_lr].max
+
+           if new_lr < old_lr
+             logs[:model].learning_rate = new_lr
+             puts "Epoch #{epoch + 1}: Reducing learning rate from #{old_lr} to #{new_lr}" if @verbose
+             @wait = 0
+           end
+         end
+       end
+     end
+
+     private
+
+     def monitor_improved?(current)
+       if @mode == 'min'
+         current < @best - @min_delta
+       else
+         current > @best + @min_delta
+       end
+     end
+   end
+
+   class CSVLogger < Callback
+     def initialize(filename:, separator: ',', append: false)
+       @filename = filename
+       @separator = separator
+       @append = append
+       @file = nil
+       @keys = nil
+     end
+
+     def on_train_begin(logs = {})
+       mode = @append ? 'a' : 'w'
+       @file = File.open(@filename, mode)
+       @keys = nil
+     end
+
+     def on_epoch_end(epoch, logs = {})
+       logs_copy = logs.dup
+       logs_copy.delete(:model)
+       logs_copy[:epoch] = epoch + 1
+
+       if @keys.nil?
+         @keys = logs_copy.keys
+         @file.puts @keys.join(@separator)
+       end
+
+       values = @keys.map { |key| logs_copy[key] || '' }
+       @file.puts values.join(@separator)
+       @file.flush
+     end
+
+     def on_train_end(logs = {})
+       @file.close if @file
+     end
+   end
+
+   class ProgbarLogger < Callback
+     def initialize(count_mode: 'steps')
+       @count_mode = count_mode
+     end
+
+     def on_train_begin(logs = {})
+       @epochs = logs[:epochs] || 1
+     end
+
+     def on_epoch_begin(epoch, logs = {})
+       puts "\nEpoch #{epoch + 1}/#{@epochs}"
+       @seen = 0
+       @total = logs[:steps] || 0
+     end
+
+     def on_batch_end(batch, logs = {})
+       @seen += 1
+       return if @total <= 0 # step count unknown; skip drawing the bar
+
+       progress = (@seen.to_f / @total * 100).round(1)
+       bar_length = 30
+       filled = [bar_length * @seen / @total, bar_length - 1].min
+       bar = '=' * filled + '>' + '.' * (bar_length - filled - 1)
+
+       print "\r#{@seen}/#{@total} [#{bar}] #{progress}%"
+       $stdout.flush
+     end
+
+     def on_epoch_end(epoch, logs = {})
+       puts ""
+     end
+   end
+ end
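
To show how these hooks are meant to be driven, here is a minimal, hypothetical training-loop sketch, not part of the package. DummyModel and the synthetic loss value are stand-ins; the only interface assumed is what the callbacks above actually call (#learning_rate, #learning_rate=, #save) and the logs keys they read (:model, :loss, :epochs). The checkpoint filename is illustrative; only the '{epoch}' placeholder comes from the code.

# Hypothetical driver loop for the GRNEXUSCallbacks hooks above.
DummyModel = Struct.new(:learning_rate) do
  def save(path)
    File.write(path, Marshal.dump(self)) # stand-in serialization
  end
end

model = DummyModel.new(0.01)

callbacks = [
  GRNEXUSCallbacks::EarlyStopping.new(monitor: 'loss', patience: 5),
  GRNEXUSCallbacks::ModelCheckpoint.new(filepath: 'model_epoch_{epoch}.bin'),
  GRNEXUSCallbacks::CSVLogger.new(filename: 'training.csv')
]

epochs = 50
callbacks.each { |cb| cb.on_train_begin(epochs: epochs) }

epochs.times do |epoch|
  callbacks.each { |cb| cb.on_epoch_begin(epoch, model: model) }
  loss = 1.0 / (epoch + 1) # stand-in for a real training step
  callbacks.each { |cb| cb.on_epoch_end(epoch, loss: loss, model: model) }
  break if callbacks.any? { |cb| cb.respond_to?(:stop_training?) && cb.stop_training? }
end

callbacks.each(&:on_train_end)

The callbacks never own the loop: the trainer is expected to pass the model and current metrics through the logs hash each epoch, which is why EarlyStopping exposes stop_training? rather than interrupting training itself.
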