grnexus 1.0.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/LICENSE +96 -0
- data/README.md +1105 -0
- data/exports/Linux/libgrnexus.so +0 -0
- data/exports/Mac/activations.dylib +0 -0
- data/exports/Mac/grnexus_core.dylib +0 -0
- data/exports/Mac/machine_learning.dylib +0 -0
- data/exports/Mac/normalization.dylib +0 -0
- data/exports/Mac/numeric_proccessing.dylib +0 -0
- data/exports/Mac/text_processing.dylib +0 -0
- data/exports/Windows/activations.dll +0 -0
- data/exports/Windows/grnexus_core.dll +0 -0
- data/exports/Windows/machine_learning.dll +0 -0
- data/exports/Windows/normalization.dll +0 -0
- data/exports/Windows/numeric_proccessing.dll +0 -0
- data/exports/Windows/text_processing.dll +0 -0
- data/lib/grnexus.rb +743 -0
- data/lib/grnexus_activations.rb +462 -0
- data/lib/grnexus_callbacks.rb +249 -0
- data/lib/grnexus_core.rb +130 -0
- data/lib/grnexus_layers.rb +1103 -0
- data/lib/grnexus_machine_learning.rb +591 -0
- data/lib/grnexus_normalization.rb +319 -0
- data/lib/grnexus_numeric_proccessing.rb +722 -0
- data/lib/grnexus_text_proccessing.rb +295 -0
- metadata +149 -0
data/lib/grnexus_activations.rb
@@ -0,0 +1,462 @@
require 'ffi'
require 'rbconfig'

module GRNEXUSActivations
  extend FFI::Library

  # Detect the operating system
  def self.detect_library
    script_dir = File.dirname(File.expand_path(__FILE__))
    os = RbConfig::CONFIG['host_os']

    case os
    when /mswin|mingw|cygwin/
      File.join(script_dir, '..', 'exports', 'Windows', 'activations.dll')
    when /darwin/
      File.join(script_dir, '..', 'exports', 'Mac', 'activations.dylib')
    when /linux/
      File.join(script_dir, '..', 'exports', 'Linux', 'activations.so')
    else
      raise "Unsupported operating system: #{os}"
    end
  end

  # Load the native library
  ffi_lib detect_library

  # Define the GRNexusData struct
  class GRNexusData < FFI::Struct
    layout :data, :pointer,
           :type, :int,
           :size, :size_t,
           :stride, :size_t,
           :dims, [:size_t, 3]
  end

  # Define the available functions
  FUNCTIONS = {
    'Linear' => [:Linear, [:pointer, :pointer, :bool], :int],
    'Step' => [:Step, [:pointer, :pointer, :bool], :int],
    'Sigmoid' => [:Sigmoid, [:pointer, :pointer, :bool], :int],
    'Tanh' => [:Tanh, [:pointer, :pointer, :bool], :int],
    'ReLU' => [:ReLU, [:pointer, :pointer, :bool], :int],
    'LeakyReLU' => [:LeakyReLU, [:pointer, :pointer, :bool, :double], :int],
    'PReLU' => [:PReLU, [:pointer, :pointer, :bool, :pointer], :int],
    'ELU' => [:ELU, [:pointer, :pointer, :bool, :double], :int],
    'SELU' => [:SELU, [:pointer, :pointer, :bool], :int],
    'Softplus' => [:Softplus, [:pointer, :pointer, :bool], :int],
    'Softsign' => [:Softsign, [:pointer, :pointer, :bool], :int],
    'HardSigmoid' => [:HardSigmoid, [:pointer, :pointer, :bool], :int],
    'HardTanh' => [:HardTanh, [:pointer, :pointer, :bool], :int],
    'ThresholdedReLU' => [:ThresholdedReLU, [:pointer, :pointer, :bool, :double], :int],
    'GELU' => [:GELU, [:pointer, :pointer, :bool], :int],
    'Swish' => [:Swish, [:pointer, :pointer, :bool], :int],
    'Mish' => [:Mish, [:pointer, :pointer, :bool], :int],
    'LiSHT' => [:LiSHT, [:pointer, :pointer, :bool], :int],
    'ReLUSquared' => [:ReLUSquared, [:pointer, :pointer, :bool], :int],
    'SquaredReLU' => [:SquaredReLU, [:pointer, :pointer, :bool], :int],
    'CELU' => [:CELU, [:pointer, :pointer, :bool, :double], :int],
    'HardShrink' => [:HardShrink, [:pointer, :pointer, :bool, :double], :int],
    'SoftShrink' => [:SoftShrink, [:pointer, :pointer, :bool, :double], :int],
    'TanhShrink' => [:TanhShrink, [:pointer, :pointer, :bool], :int],
    'ReLU6' => [:ReLU6, [:pointer, :pointer, :bool], :int],
    'HardSwish' => [:HardSwish, [:pointer, :pointer, :bool], :int],
    'SiLU' => [:SiLU, [:pointer, :pointer, :bool], :int],
    'GLU' => [:GLU, [:pointer, :pointer, :bool, :size_t], :int],
    'BReLU' => [:BReLU, [:pointer, :pointer, :bool], :int],
    'ARelu' => [:ARelu, [:pointer, :pointer, :bool, :double, :double], :int],
    'FReLU' => [:FReLU, [:pointer, :pointer, :bool, :double], :int],
    'Snake' => [:Snake, [:pointer, :pointer, :bool, :double], :int],
    'SnakeBeta' => [:SnakeBeta, [:pointer, :pointer, :bool, :double, :double], :int],
    'ISRU' => [:ISRU, [:pointer, :pointer, :bool, :double], :int],
    'ISRLU' => [:ISRLU, [:pointer, :pointer, :bool, :double], :int],
    'Maxout' => [:Maxout, [:pointer, :pointer, :bool, :size_t], :int],
    'Minout' => [:Minout, [:pointer, :pointer, :bool, :size_t], :int],
  }

  # Attach the functions to the module
  FUNCTIONS.each do |name, (func, args, ret)|
    attach_function func, args, ret
  end

  # Function to create a GRNexusData
  def self.create_grnexus_data(array_or_scalar)
    # Convert to an array if it is a number
    values = array_or_scalar.is_a?(Array) ? array_or_scalar : [array_or_scalar]
    size = values.length

    # Create a buffer holding the data
    buffer = FFI::MemoryPointer.new(:double, size)
    buffer.write_array_of_double(values)

    data = FFI::MemoryPointer.new(GRNexusData.size)
    struct_instance = GRNexusData.new(data)
    struct_instance[:data] = buffer
    struct_instance[:type] = size == 1 ? 0 : 1 # SCALAR=0, ARRAY=1
    struct_instance[:size] = size
    struct_instance[:stride] = 1
    struct_instance[:dims][0] = size
    struct_instance[:dims][1] = 0
    struct_instance[:dims][2] = 0

    [data, buffer]
  end

  # Function to create a GRNexusData for output
  def self.create_output_grnexus_data(size)
    # Create an empty, writable buffer
    buffer = FFI::MemoryPointer.new(:double, size)

    data = FFI::MemoryPointer.new(GRNexusData.size)
    struct_instance = GRNexusData.new(data)
    struct_instance[:data] = buffer
    struct_instance[:type] = 1 # ARRAY
    struct_instance[:size] = size
    struct_instance[:stride] = 1
    struct_instance[:dims][0] = size
    struct_instance[:dims][1] = 0
    struct_instance[:dims][2] = 0

    [data, buffer]
  end

  # Function to read data back out of a GRNexusData
  def self.read_grnexus_data(original_ptr, size)
    original_ptr.read_array_of_double(size)
  end

  class ActivationFunction
    def call(x, derivative: false)
      raise NotImplementedError, "You must implement the activation function"
    end
  end

  class BaseActivation < ActivationFunction
    protected

    def execute_activation(func_name, input_values, derivative: false, **kwargs)
      input_data, input_buffer = GRNEXUSActivations.create_grnexus_data(input_values)
      output_size = input_values.is_a?(Array) ? input_values.length : 1
      output_data, output_buffer = GRNEXUSActivations.create_output_grnexus_data(output_size)

      args = [input_data, output_data, derivative]

      # Add function-specific arguments
      case func_name
      when :LeakyReLU, :ELU, :ThresholdedReLU, :CELU, :HardShrink, :SoftShrink, :FReLU, :Snake, :ISRU, :ISRLU
        args << (kwargs[:param1] || 0.01)
      when :PReLU
        alpha_ptr = FFI::MemoryPointer.new(:double)
        alpha_ptr.write_double(kwargs[:param1] || 0.01)
        args << alpha_ptr
      when :ARelu
        args << (kwargs[:param1] || 0.01) << (kwargs[:param2] || 1.0)
      when :SnakeBeta
        args << (kwargs[:param1] || 1.0) << (kwargs[:param2] || 1.0)
      when :GLU
        args << (kwargs[:dim] || 0)
      when :Maxout, :Minout
        args << (kwargs[:num_pieces] || 2)
      end

      # Call the function on the GRNEXUSActivations module, not on the instance
      result = GRNEXUSActivations.send(func_name, *args)

      if result == 0
        output_buffer.read_array_of_double(output_size)
      else
        raise "Function #{func_name} failed with code: #{result}"
      end
    ensure
      # FFI::MemoryPointer is released automatically by the GC
    end
  end
  class Linear < BaseActivation
    def call(x, derivative: false)
      execute_activation(:Linear, x, derivative: derivative)
    end
  end

  class Step < BaseActivation
    def call(x, derivative: false)
      execute_activation(:Step, x, derivative: derivative)
    end
  end

  class Sigmoid < BaseActivation
    def call(x, derivative: false)
      execute_activation(:Sigmoid, x, derivative: derivative)
    end
  end

  class Tanh < BaseActivation
    def call(x, derivative: false)
      execute_activation(:Tanh, x, derivative: derivative)
    end
  end

  class ReLU < BaseActivation
    def call(x, derivative: false)
      execute_activation(:ReLU, x, derivative: derivative)
    end
  end

  class LeakyReLU < BaseActivation
    def initialize(alpha: 0.01)
      @alpha = alpha
    end

    def call(x, derivative: false)
      execute_activation(:LeakyReLU, x, derivative: derivative, param1: @alpha)
    end
  end

  class PReLU < BaseActivation
    def initialize(alpha: 0.01)
      @alpha = alpha
    end

    def call(x, derivative: false)
      execute_activation(:PReLU, x, derivative: derivative, param1: @alpha)
    end
  end

  class ELU < BaseActivation
    def initialize(alpha: 1.0)
      @alpha = alpha
    end

    def call(x, derivative: false)
      execute_activation(:ELU, x, derivative: derivative, param1: @alpha)
    end
  end

  class SELU < BaseActivation
    def call(x, derivative: false)
      execute_activation(:SELU, x, derivative: derivative)
    end
  end

  class Softplus < BaseActivation
    def call(x, derivative: false)
      execute_activation(:Softplus, x, derivative: derivative)
    end
  end

  class Softsign < BaseActivation
    def call(x, derivative: false)
      execute_activation(:Softsign, x, derivative: derivative)
    end
  end

  class HardSigmoid < BaseActivation
    def call(x, derivative: false)
      execute_activation(:HardSigmoid, x, derivative: derivative)
    end
  end

  class HardTanh < BaseActivation
    def call(x, derivative: false)
      execute_activation(:HardTanh, x, derivative: derivative)
    end
  end

  class ThresholdedReLU < BaseActivation
    def initialize(theta: 1.0)
      @theta = theta
    end

    def call(x, derivative: false)
      execute_activation(:ThresholdedReLU, x, derivative: derivative, param1: @theta)
    end
  end

  class GELU < BaseActivation
    def call(x, derivative: false)
      execute_activation(:GELU, x, derivative: derivative)
    end
  end

  class Swish < BaseActivation
    def call(x, derivative: false)
      execute_activation(:Swish, x, derivative: derivative)
    end
  end

  class Mish < BaseActivation
    def call(x, derivative: false)
      execute_activation(:Mish, x, derivative: derivative)
    end
  end

  class LiSHT < BaseActivation
    def call(x, derivative: false)
      execute_activation(:LiSHT, x, derivative: derivative)
    end
  end

  class ReLUSquared < BaseActivation
    def call(x, derivative: false)
      execute_activation(:ReLUSquared, x, derivative: derivative)
    end
  end

  class SquaredReLU < BaseActivation
    def call(x, derivative: false)
      execute_activation(:SquaredReLU, x, derivative: derivative)
    end
  end

  class CELU < BaseActivation
    def initialize(alpha: 1.0)
      @alpha = alpha
    end

    def call(x, derivative: false)
      execute_activation(:CELU, x, derivative: derivative, param1: @alpha)
    end
  end

  class HardShrink < BaseActivation
    def initialize(lambda: 0.5)
      @lambda = lambda
    end

    def call(x, derivative: false)
      execute_activation(:HardShrink, x, derivative: derivative, param1: @lambda)
    end
  end

  class SoftShrink < BaseActivation
    def initialize(lambda: 0.5)
      @lambda = lambda
    end

    def call(x, derivative: false)
      execute_activation(:SoftShrink, x, derivative: derivative, param1: @lambda)
    end
  end

  class TanhShrink < BaseActivation
    def call(x, derivative: false)
      execute_activation(:TanhShrink, x, derivative: derivative)
    end
  end

  class ReLU6 < BaseActivation
    def call(x, derivative: false)
      execute_activation(:ReLU6, x, derivative: derivative)
    end
  end

  class HardSwish < BaseActivation
    def call(x, derivative: false)
      execute_activation(:HardSwish, x, derivative: derivative)
    end
  end

  class SiLU < BaseActivation
    def call(x, derivative: false)
      execute_activation(:SiLU, x, derivative: derivative)
    end
  end

  class GLU < BaseActivation
    def initialize(dim: 1)
      @dim = dim
    end

    def call(x, derivative: false)
      execute_activation(:GLU, x, derivative: derivative, dim: @dim)
    end
  end

  class BReLU < BaseActivation
    def call(x, derivative: false)
      execute_activation(:BReLU, x, derivative: derivative)
    end
  end

  class ARelu < BaseActivation
    def initialize(alpha: 0.0, beta: 1.0)
      @alpha = alpha
      @beta = beta
    end

    def call(x, derivative: false)
      execute_activation(:ARelu, x, derivative: derivative, param1: @alpha, param2: @beta)
    end
  end

  class FReLU < BaseActivation
    def initialize(alpha: 1.0)
      @alpha = alpha
    end

    def call(x, derivative: false)
      execute_activation(:FReLU, x, derivative: derivative, param1: @alpha)
    end
  end

  class Snake < BaseActivation
    def initialize(alpha: 1.0)
      @alpha = alpha
    end

    def call(x, derivative: false)
      execute_activation(:Snake, x, derivative: derivative, param1: @alpha)
    end
  end

  class SnakeBeta < BaseActivation
    def initialize(alpha: 1.0, beta: 1.0)
      @alpha = alpha
      @beta = beta
    end

    def call(x, derivative: false)
      execute_activation(:SnakeBeta, x, derivative: derivative, param1: @alpha, param2: @beta)
    end
  end

  class ISRU < BaseActivation
    def initialize(alpha: 1.0)
      @alpha = alpha
    end

    def call(x, derivative: false)
      execute_activation(:ISRU, x, derivative: derivative, param1: @alpha)
    end
  end

  class ISRLU < BaseActivation
    def initialize(alpha: 1.0)
      @alpha = alpha
    end

    def call(x, derivative: false)
      execute_activation(:ISRLU, x, derivative: derivative, param1: @alpha)
    end
  end

  class Maxout < BaseActivation
    def initialize(num_pieces: 2)
      @num_pieces = num_pieces
    end

    def call(x, derivative: false)
      execute_activation(:Maxout, x, derivative: derivative, num_pieces: @num_pieces)
    end
  end

  class Minout < BaseActivation
    def initialize(num_pieces: 2)
      @num_pieces = num_pieces
    end

    def call(x, derivative: false)
      execute_activation(:Minout, x, derivative: derivative, num_pieces: @num_pieces)
    end
  end
end
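Taken together, the bindings above marshal a Ruby Array (or scalar) into the GRNexusData struct, call the attached native function, and read the doubles back out of the output buffer when the return code is 0. A minimal usage sketch, not part of the published gem: it assumes data/lib is on the load path and that the native activations library for your platform loads successfully.

require 'grnexus_activations'

# Parameter-free activation: forward pass and derivative over an array input
relu = GRNEXUSActivations::ReLU.new
relu.call([-1.0, 0.5, 2.0])                       # expected ReLU output: [0.0, 0.5, 2.0]
relu.call([-1.0, 0.5, 2.0], derivative: true)     # derivative of ReLU at the same points

# Parameterised activation: alpha is forwarded to the native call as param1
leaky = GRNEXUSActivations::LeakyReLU.new(alpha: 0.1)
leaky.call([-2.0, 3.0])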
data/lib/grnexus_callbacks.rb
@@ -0,0 +1,249 @@
# Callbacks system for training
module GRNEXUSCallbacks
  class Callback
    def on_train_begin(logs = {})
    end

    def on_train_end(logs = {})
    end

    def on_epoch_begin(epoch, logs = {})
    end

    def on_epoch_end(epoch, logs = {})
    end

    def on_batch_begin(batch, logs = {})
    end

    def on_batch_end(batch, logs = {})
    end
  end

  class EarlyStopping < Callback
    attr_reader :stopped_epoch

    def initialize(monitor: 'loss', patience: 10, min_delta: 0.0001, mode: 'min', verbose: true)
      @monitor = monitor
      @patience = patience
      @min_delta = min_delta
      @mode = mode
      @verbose = verbose
      @wait = 0
      @stopped_epoch = 0
      @best = mode == 'min' ? Float::INFINITY : -Float::INFINITY
      @stop_training = false
    end

    def on_train_begin(logs = {})
      @wait = 0
      @stopped_epoch = 0
      @best = @mode == 'min' ? Float::INFINITY : -Float::INFINITY
      @stop_training = false
    end

    def on_epoch_end(epoch, logs = {})
      current = logs[@monitor.to_sym]
      return unless current

      if monitor_improved?(current)
        @best = current
        @wait = 0
      else
        @wait += 1
        if @wait >= @patience
          @stopped_epoch = epoch
          @stop_training = true
          puts "Early stopping triggered at epoch #{epoch + 1}" if @verbose
        end
      end
    end

    def stop_training?
      @stop_training
    end

    private

    def monitor_improved?(current)
      if @mode == 'min'
        current < @best - @min_delta
      else
        current > @best + @min_delta
      end
    end
  end

  class ModelCheckpoint < Callback
    def initialize(filepath:, monitor: 'loss', mode: 'min', save_best_only: true, verbose: true)
      @filepath = filepath
      @monitor = monitor
      @mode = mode
      @save_best_only = save_best_only
      @verbose = verbose
      @best = mode == 'min' ? Float::INFINITY : -Float::INFINITY
    end

    def on_epoch_end(epoch, logs = {})
      current = logs[@monitor.to_sym]
      return unless current

      should_save = !@save_best_only || monitor_improved?(current)

      if should_save
        @best = current if monitor_improved?(current)
        filepath = @filepath.gsub('{epoch}', (epoch + 1).to_s)
        logs[:model].save(filepath)
        puts "Epoch #{epoch + 1}: #{@monitor} improved to #{current.round(6)}, saving model to #{filepath}" if @verbose
      end
    end

    private

    def monitor_improved?(current)
      if @mode == 'min'
        current < @best
      else
        current > @best
      end
    end
  end

  class LearningRateScheduler < Callback
    def initialize(schedule:, verbose: true)
      @schedule = schedule
      @verbose = verbose
    end

    def on_epoch_begin(epoch, logs = {})
      new_lr = if @schedule.is_a?(Proc)
                 @schedule.call(epoch)
               elsif @schedule.is_a?(Hash)
                 @schedule[epoch] || logs[:model].learning_rate
               else
                 logs[:model].learning_rate
               end

      logs[:model].learning_rate = new_lr
      puts "Epoch #{epoch + 1}: Learning rate set to #{new_lr}" if @verbose
    end
  end

  class ReduceLROnPlateau < Callback
    def initialize(monitor: 'loss', factor: 0.5, patience: 5, min_delta: 0.0001,
                   min_lr: 1e-7, mode: 'min', verbose: true)
      @monitor = monitor
      @factor = factor
      @patience = patience
      @min_delta = min_delta
      @min_lr = min_lr
      @mode = mode
      @verbose = verbose
      @wait = 0
      @best = mode == 'min' ? Float::INFINITY : -Float::INFINITY
    end

    def on_train_begin(logs = {})
      @wait = 0
      @best = @mode == 'min' ? Float::INFINITY : -Float::INFINITY
    end

    def on_epoch_end(epoch, logs = {})
      current = logs[@monitor.to_sym]
      return unless current

      if monitor_improved?(current)
        @best = current
        @wait = 0
      else
        @wait += 1
        if @wait >= @patience
          old_lr = logs[:model].learning_rate
          new_lr = [old_lr * @factor, @min_lr].max

          if new_lr < old_lr
            logs[:model].learning_rate = new_lr
            puts "Epoch #{epoch + 1}: Reducing learning rate from #{old_lr} to #{new_lr}" if @verbose
            @wait = 0
          end
        end
      end
    end

    private

    def monitor_improved?(current)
      if @mode == 'min'
        current < @best - @min_delta
      else
        current > @best + @min_delta
      end
    end
  end

  class CSVLogger < Callback
    def initialize(filename:, separator: ',', append: false)
      @filename = filename
      @separator = separator
      @append = append
      @file = nil
      @keys = nil
    end

    def on_train_begin(logs = {})
      mode = @append ? 'a' : 'w'
      @file = File.open(@filename, mode)
      @keys = nil
    end

    def on_epoch_end(epoch, logs = {})
      logs_copy = logs.dup
      logs_copy.delete(:model)
      logs_copy[:epoch] = epoch + 1

      if @keys.nil?
        @keys = logs_copy.keys
        @file.puts @keys.join(@separator)
      end

      values = @keys.map { |key| logs_copy[key] || '' }
      @file.puts values.join(@separator)
      @file.flush
    end

    def on_train_end(logs = {})
      @file.close if @file
    end
  end

  class ProgbarLogger < Callback
    def initialize(count_mode: 'steps')
      @count_mode = count_mode
    end

    def on_train_begin(logs = {})
      @epochs = logs[:epochs] || 1
    end

    def on_epoch_begin(epoch, logs = {})
      puts "\nEpoch #{epoch + 1}/#{@epochs}"
      @seen = 0
      @total = logs[:steps] || 0
    end

    def on_batch_end(batch, logs = {})
      @seen += 1
      progress = (@seen.to_f / @total * 100).round(1)
      bar_length = 30
      filled = (bar_length * @seen / @total).to_i
      bar = '=' * filled + '>' + '.' * (bar_length - filled - 1)

      print "\r#{@seen}/#{@total} [#{bar}] #{progress}%"
      $stdout.flush
    end

    def on_epoch_end(epoch, logs = {})
      puts ""
    end
  end
end
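These callbacks follow a Keras-style hook protocol: whatever training loop drives them is expected to call on_train_begin, on_epoch_begin, on_epoch_end, and on_train_end with a logs Hash keyed by symbols (including :model for callbacks that read the learning rate or save the model). A minimal driver sketch with a stand-in loss sequence instead of a real model, purely illustrative and not part of the gem:

require 'grnexus_callbacks'

early  = GRNEXUSCallbacks::EarlyStopping.new(monitor: 'loss', patience: 2)
logger = GRNEXUSCallbacks::CSVLogger.new(filename: 'history.csv')
callbacks = [early, logger]

losses = [0.9, 0.5, 0.41, 0.42, 0.43, 0.44]  # stand-in metric values
callbacks.each { |cb| cb.on_train_begin }
losses.each_with_index do |loss, epoch|
  callbacks.each { |cb| cb.on_epoch_begin(epoch) }
  callbacks.each { |cb| cb.on_epoch_end(epoch, { loss: loss }) }
  break if early.stop_training?   # EarlyStopping flags a stop after `patience` epochs without improvement
end
callbacks.each { |cb| cb.on_train_end }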