grnexus 1.0.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,722 @@
1
+ require 'ffi'
2
+ require 'rbconfig'
3
+
4
+ module GRNEXUSNumericProcessing
5
+ extend FFI::Library
6
+
7
# Resolve the path of the bundled native library for the host operating system.
#
# Binaries are looked up in ../exports/<Windows|Mac|Linux>/ relative to this
# file. The "numeric_proccessing" spelling is presumably the name of the shipped
# artifacts — do not change it here without renaming those files too.
#
# @return [String] library path (not normalized; it contains a ".." segment)
# @raise [RuntimeError] when the host OS is not Windows, macOS or Linux
def self.detect_library
  host_os = RbConfig::CONFIG['host_os']

  platform_dir, extension =
    if host_os.match?(/mswin|mingw|cygwin/)
      %w[Windows dll]
    elsif host_os.match?(/darwin/)
      %w[Mac dylib]
    elsif host_os.match?(/linux/)
      %w[Linux so]
    else
      raise "Sistema operativo no soportado: #{host_os}"
    end

  base_dir = File.dirname(File.expand_path(__FILE__))
  File.join(base_dir, '..', 'exports', platform_dir, "numeric_proccessing.#{extension}")
end
23
+
24
# Load the platform-specific native library resolved above.
ffi_lib(detect_library)

# Ruby mirror of the native GRNexusData struct. FFI derives field offsets from
# this declaration, so the order and types must match the C definition exactly.
class GRNexusData < FFI::Struct
  layout(
    :data,   :pointer,     # address of a buffer of doubles
    :type,   :int,         # tag written by the helpers: SCALAR=0, ARRAY=1
    :size,   :size_t,      # number of doubles behind :data
    :stride, :size_t,      # the helpers in this file always write 1
    :dims,   [:size_t, 3]  # the helpers write [size, 0, 0]
  )
end
35
+
36
# Native entry points to attach (per the original note, mirroring
# numeric_proccessing.c). Each value is [symbol, argument types, return type];
# every function returns an int status code.
# Frozen: this is a constant signature table and must not be mutated after
# the functions below have been attached.
FUNCTIONS = {
  # Binary operations between arrays: (in1, in2, out, derivative)
  'SumArray' => [:SumArray, [:pointer, :pointer, :pointer, :bool], :int],
  'ProductArray' => [:ProductArray, [:pointer, :pointer, :pointer, :bool], :int],
  'SubtractArray' => [:SubtractArray, [:pointer, :pointer, :pointer, :bool], :int],
  'DivideArray' => [:DivideArray, [:pointer, :pointer, :pointer, :bool], :int],
  'PowerArray' => [:PowerArray, [:pointer, :pointer, :pointer, :bool], :int],
  'ModuloArray' => [:ModuloArray, [:pointer, :pointer, :pointer, :bool], :int],
  'MaxArray' => [:MaxArray, [:pointer, :pointer, :pointer, :bool], :int],
  'MinArray' => [:MinArray, [:pointer, :pointer, :pointer, :bool], :int],

  # Operations with a scalar: (in, scalar, out, derivative)
  'AddScalarToArray' => [:AddScalarToArray, [:pointer, :double, :pointer, :bool], :int],
  'MultiplyScalarToArray' => [:MultiplyScalarToArray, [:pointer, :double, :pointer, :bool], :int],

  # Statistical operations (single-value output): (in, out, derivative)
  'MeanArray' => [:MeanArray, [:pointer, :pointer, :bool], :int],
  'VarianceArray' => [:VarianceArray, [:pointer, :pointer, :bool], :int],
  'StdArray' => [:StdArray, [:pointer, :pointer, :bool], :int],
  'MaxValueArray' => [:MaxValueArray, [:pointer, :pointer, :bool], :int],
  'MinValueArray' => [:MinValueArray, [:pointer, :pointer, :bool], :int],
  'SumAllArray' => [:SumAllArray, [:pointer, :pointer, :bool], :int],
  'ProductAllArray' => [:ProductAllArray, [:pointer, :pointer, :bool], :int],

  # Array operations
  'ConcatenateArrays' => [:ConcatenateArrays, [:pointer, :pointer, :pointer, :bool], :int],
  'MovingAverage' => [:MovingAverage, [:pointer, :pointer, :bool, :size_t], :int],
  'FiniteDifference' => [:FiniteDifference, [:pointer, :pointer, :bool], :int],
  'DiscreteIntegral' => [:DiscreteIntegral, [:pointer, :pointer, :bool], :int],
  'ZScoreNormalize' => [:ZScoreNormalize, [:pointer, :pointer, :bool], :int],
  'MinMaxNormalize' => [:MinMaxNormalize, [:pointer, :pointer, :bool, :double, :double], :int],

  # Numeric activation functions: (in, out, derivative[, extra params])
  'LinearNum' => [:LinearNum, [:pointer, :pointer, :bool], :int],
  'ReLUNum' => [:ReLUNum, [:pointer, :pointer, :bool], :int],
  'SigmoidNum' => [:SigmoidNum, [:pointer, :pointer, :bool], :int],
  'TanhNum' => [:TanhNum, [:pointer, :pointer, :bool], :int],
  'LeakyReLUNum' => [:LeakyReLUNum, [:pointer, :pointer, :bool, :double], :int],
  'ELUNum' => [:ELUNum, [:pointer, :pointer, :bool, :double], :int],
  'SoftplusNum' => [:SoftplusNum, [:pointer, :pointer, :bool], :int],
  'GELUNum' => [:GELUNum, [:pointer, :pointer, :bool], :int],
  'SwishNum' => [:SwishNum, [:pointer, :pointer, :bool], :int],
  'MishNum' => [:MishNum, [:pointer, :pointer, :bool], :int],
  'SiLUNum' => [:SiLUNum, [:pointer, :pointer, :bool], :int],
  'HardTanhNum' => [:HardTanhNum, [:pointer, :pointer, :bool], :int],
  'HardSigmoidNum' => [:HardSigmoidNum, [:pointer, :pointer, :bool], :int],
  'HardSwishNum' => [:HardSwishNum, [:pointer, :pointer, :bool], :int],
  'SoftsignNum' => [:SoftsignNum, [:pointer, :pointer, :bool], :int],
  'SELUNum' => [:SELUNum, [:pointer, :pointer, :bool], :int],
  'CELUNum' => [:CELUNum, [:pointer, :pointer, :bool, :double], :int],
  'ISRUNum' => [:ISRUNum, [:pointer, :pointer, :bool, :double], :int],
  'ISRLUNum' => [:ISRLUNum, [:pointer, :pointer, :bool, :double], :int],
  'ReLU6Num' => [:ReLU6Num, [:pointer, :pointer, :bool], :int],
  'ThresholdedReLUNum' => [:ThresholdedReLUNum, [:pointer, :pointer, :bool, :double], :int],
  'PReLUNum' => [:PReLUNum, [:pointer, :pointer, :bool, :pointer], :int],
  'SquaredReLUNum' => [:SquaredReLUNum, [:pointer, :pointer, :bool], :int],
  'LiSHTNum' => [:LiSHTNum, [:pointer, :pointer, :bool], :int],
  'SnakeNum' => [:SnakeNum, [:pointer, :pointer, :bool, :double], :int],
  'SnakeBetaNum' => [:SnakeBetaNum, [:pointer, :pointer, :bool, :double, :double], :int],
  'AReluNum' => [:AReluNum, [:pointer, :pointer, :bool, :double, :double], :int],
  'FReLUNum' => [:FReLUNum, [:pointer, :pointer, :bool, :double], :int],
  'BReLUNum' => [:BReLUNum, [:pointer, :pointer, :bool], :int],
  'HardShrinkNum' => [:HardShrinkNum, [:pointer, :pointer, :bool, :double], :int],
  'SoftShrinkNum' => [:SoftShrinkNum, [:pointer, :pointer, :bool, :double], :int],
  'TanhShrinkNum' => [:TanhShrinkNum, [:pointer, :pointer, :bool], :int],
  'MaxoutNum' => [:MaxoutNum, [:pointer, :pointer, :bool, :size_t], :int],
  'MinoutNum' => [:MinoutNum, [:pointer, :pointer, :bool, :size_t], :int],
  'GLUNum' => [:GLUNum, [:pointer, :pointer, :bool, :size_t], :int],
  'ReLUSquaredNum' => [:ReLUSquaredNum, [:pointer, :pointer, :bool], :int]
}.freeze

# Attach every entry so it can be invoked as GRNEXUSNumericProcessing.<Name>(...).
# The string keys are descriptive only; the attached name comes from the value.
FUNCTIONS.each_value do |func, args, ret|
  attach_function func, args, ret
end
112
+
113
# Wrap a Ruby number or (possibly nested) array as an input GRNexusData.
#
# Nested arrays are flattened and every element is written as a double. Any
# input with exactly one element, including a one-element array, is tagged
# SCALAR (0); everything else is tagged ARRAY (1).
#
# @param array_or_scalar [Numeric, Array<Numeric>] input values
# @return [Array(FFI::MemoryPointer, FFI::MemoryPointer)] [struct, data buffer].
#   The struct memory stores only the buffer's address, so callers keep the
#   buffer referenced for as long as the struct is in use.
def self.create_grnexus_data(array_or_scalar)
  flat_values =
    if array_or_scalar.is_a?(Array)
      array_or_scalar.flatten
    else
      [array_or_scalar]
    end
  count = flat_values.length

  values_ptr = FFI::MemoryPointer.new(:double, count)
  values_ptr.write_array_of_double(flat_values)

  struct_ptr = FFI::MemoryPointer.new(GRNexusData.size)
  fields = GRNexusData.new(struct_ptr)
  kind = count == 1 ? 0 : 1 # SCALAR=0, ARRAY=1
  { data: values_ptr, type: kind, size: count, stride: 1 }.each do |name, value|
    fields[name] = value
  end
  [count, 0, 0].each_with_index { |extent, axis| fields[:dims][axis] = extent }

  [struct_ptr, values_ptr]
end
135
+
136
# Allocate an output GRNexusData (tagged ARRAY) whose data buffer has room for
# `size` doubles for the native function to write into.
#
# @param size [Integer] number of doubles to allocate
# @return [Array(FFI::MemoryPointer, FFI::MemoryPointer)] [struct, data buffer];
#   read the results back from the data buffer after the native call.
def self.create_output_grnexus_data(size)
  out_values = FFI::MemoryPointer.new(:double, size)

  out_struct_ptr = FFI::MemoryPointer.new(GRNexusData.size)
  fields = GRNexusData.new(out_struct_ptr)
  { data: out_values, type: 1, size: size, stride: 1 }.each do |name, value|
    fields[name] = value # type 1 = ARRAY
  end
  [size, 0, 0].each_with_index { |extent, axis| fields[:dims][axis] = extent }

  [out_struct_ptr, out_values]
end
153
+
154
# Read `count` doubles starting at `pointer`.
#
# NOTE(review): this reads doubles directly from the given pointer, so it
# expects the *data buffer* (second element returned by the create_* helpers),
# not the struct pointer — passing the struct would reinterpret its fields.
#
# @param pointer [FFI::Pointer] start of a buffer of doubles
# @param count [Integer] number of doubles to read
# @return [Array<Float>]
def self.read_grnexus_data(pointer, count)
  pointer.read_array_of_double(count)
end
158
+
159
# Abstract root of every numeric processor; concrete classes override #process.
class NumericProcessor
  # Placeholder that accepts any arguments and always fails.
  #
  # @raise [NotImplementedError] always
  def process(*_args)
    message = "Debes implementar el método de procesamiento numérico"
    raise NotImplementedError, message
  end
end
165
+
166
# Base class for operations backed by the native library. It marshals Ruby
# numbers/arrays into GRNexusData structs, invokes the attached native function,
# checks its int status (0 = success) and unpacks the output buffer.
class BaseNumericOperation < NumericProcessor
  # Native functions with a single-value result; it is returned to Ruby as a
  # Float instead of a one-element Array. Hoisted so it is not rebuilt per call.
  SCALAR_OUTPUT_FUNCTIONS = %i[MeanArray VarianceArray StdArray MaxValueArray
                               MinValueArray SumAllArray ProductAllArray].freeze

  # Window used by MovingAverage when no :window_size is supplied.
  DEFAULT_WINDOW_SIZE = 3

  protected

  # Run a native function with signature (in1, in2, out, derivative).
  #
  # @param func_name [Symbol] name of the attached native function
  # @param input1 [Numeric, Array] first operand (nested arrays are flattened)
  # @param input2 [Numeric, Array] second operand (nested arrays are flattened)
  # @param derivative [Object] truthiness forwarded to the native bool flag
  # @return [Array<Float>] contents of the output buffer
  # @raise [RuntimeError] if the native function returns a non-zero status
  def execute_binary_operation(func_name, input1, input2, derivative: false)
    # The buffer locals are deliberately kept: the structs hold only raw
    # addresses, so the MemoryPointers must stay referenced during the call.
    input1_data, _input1_buffer = GRNEXUSNumericProcessing.create_grnexus_data(input1)
    input2_data, _input2_buffer = GRNEXUSNumericProcessing.create_grnexus_data(input2)

    input1_size = flat_size(input1)
    # Concatenation holds both inputs; other binary ops are sized like input1.
    # NOTE(review): if the native code broadcasts a scalar input1 against an
    # array input2, this buffer holds only one double — confirm the C side
    # bounds its writes by output->size.
    output_size = func_name == :ConcatenateArrays ? input1_size + flat_size(input2) : input1_size

    output_data, output_buffer = GRNEXUSNumericProcessing.create_output_grnexus_data(output_size)
    call_native(func_name, input1_data, input2_data, output_data, native_bool(derivative))
    output_buffer.read_array_of_double(output_size)
  end

  # Run a native function with signature (in, out, derivative, *extra).
  #
  # @param func_name [Symbol] name of the attached native function
  # @param input [Numeric, Array] operand (nested arrays are flattened)
  # @param derivative [Object] truthiness forwarded to the native bool flag
  # @param kwargs [Hash] extra native arguments, depending on func_name:
  #   :window_size (MovingAverage), :min_range/:max_range (MinMaxNormalize),
  #   :param1/:param2 (parameterised activations), :num_pieces (Maxout/Minout/GLU)
  # @return [Float, Array<Float>] a Float for SCALAR_OUTPUT_FUNCTIONS, else an Array
  # @raise [RuntimeError] if the native function returns a non-zero status
  def execute_unary_operation(func_name, input, derivative: false, **kwargs)
    # Keep the buffer referenced while the native call reads it.
    input_data, _input_buffer = GRNEXUSNumericProcessing.create_grnexus_data(input)
    input_size = flat_size(input)
    scalar_output = SCALAR_OUTPUT_FUNCTIONS.include?(func_name.to_sym)
    window_size = kwargs[:window_size] || DEFAULT_WINDOW_SIZE

    output_size =
      if scalar_output
        1
      elsif func_name == :MovingAverage
        [input_size - window_size + 1, 0].max
      elsif func_name == :FiniteDifference
        [input_size - 1, 0].max
      else
        input_size
      end

    output_data, output_buffer = GRNEXUSNumericProcessing.create_output_grnexus_data(output_size)
    # Holding extra_args keeps any pointer argument (PReLU alpha) alive for the call.
    extra_args = extra_native_args(func_name, kwargs, window_size)
    call_native(func_name, input_data, output_data, native_bool(derivative), *extra_args)

    result = output_buffer.read_array_of_double(output_size)
    scalar_output ? result.first : result
  end

  # Run a native function with signature (in, scalar, out, derivative).
  #
  # @param func_name [Symbol] name of the attached native function
  # @param input [Numeric, Array] operand (nested arrays are flattened)
  # @param scalar [Numeric] value passed as the native double argument
  # @param derivative [Object] truthiness forwarded to the native bool flag
  # @return [Array<Float>] output buffer, same length as the flattened input
  # @raise [RuntimeError] if the native function returns a non-zero status
  def execute_scalar_operation(func_name, input, scalar, derivative: false)
    # Keep the buffer referenced while the native call reads it.
    input_data, _input_buffer = GRNEXUSNumericProcessing.create_grnexus_data(input)
    input_size = flat_size(input)
    output_data, output_buffer = GRNEXUSNumericProcessing.create_output_grnexus_data(input_size)

    call_native(func_name, input_data, scalar, output_data, native_bool(derivative))
    output_buffer.read_array_of_double(input_size)
  end

  private

  # Number of doubles create_grnexus_data writes for +value+.
  def flat_size(value)
    value.is_a?(Array) ? value.flatten.length : 1
  end

  # FFI's :bool type accepts only true/false (nil raises TypeError); map Ruby
  # truthiness onto those two values.
  def native_bool(value)
    value ? true : false
  end

  # Invoke the attached native function and raise unless it returns 0.
  def call_native(func_name, *args)
    status = GRNEXUSNumericProcessing.send(func_name, *args)
    return if status.zero?

    raise "Función #{func_name} falló con código: #{status}"
  end

  # Trailing native arguments for +func_name+, with the historical defaults.
  def extra_native_args(func_name, kwargs, window_size)
    case func_name
    when :MovingAverage
      [window_size]
    when :MinMaxNormalize
      [kwargs[:min_range] || 0.0, kwargs[:max_range] || 1.0]
    when :LeakyReLUNum, :ELUNum, :CELUNum, :ISRUNum, :ISRLUNum, :ThresholdedReLUNum,
         :FReLUNum, :HardShrinkNum, :SoftShrinkNum, :SnakeNum
      [kwargs[:param1] || 0.01]
    when :SnakeBetaNum, :AReluNum
      [kwargs[:param1] || 1.0, kwargs[:param2] || 1.0]
    when :PReLUNum
      # The native signature takes alpha by pointer.
      alpha_ptr = FFI::MemoryPointer.new(:double)
      alpha_ptr.write_double(kwargs[:param1] || 0.01)
      [alpha_ptr]
    when :MaxoutNum, :MinoutNum, :GLUNum
      [kwargs[:num_pieces] || 2]
    else
      []
    end
  end
end
270
+
271
# ============================================================================
# BINARY OPERATIONS BETWEEN ARRAYS
# ============================================================================
# Each class passes (lhs, rhs) to the native function of the same name through
# BaseNumericOperation#execute_binary_operation and returns the output array.

class SumArray < BaseNumericOperation
  def process(lhs, rhs, derivative: false)
    execute_binary_operation(:SumArray, lhs, rhs, derivative: derivative)
  end
end

class ProductArray < BaseNumericOperation
  def process(lhs, rhs, derivative: false)
    execute_binary_operation(:ProductArray, lhs, rhs, derivative: derivative)
  end
end

class SubtractArray < BaseNumericOperation
  def process(lhs, rhs, derivative: false)
    execute_binary_operation(:SubtractArray, lhs, rhs, derivative: derivative)
  end
end

class DivideArray < BaseNumericOperation
  def process(lhs, rhs, derivative: false)
    execute_binary_operation(:DivideArray, lhs, rhs, derivative: derivative)
  end
end

class PowerArray < BaseNumericOperation
  def process(lhs, rhs, derivative: false)
    execute_binary_operation(:PowerArray, lhs, rhs, derivative: derivative)
  end
end

class ModuloArray < BaseNumericOperation
  def process(lhs, rhs, derivative: false)
    execute_binary_operation(:ModuloArray, lhs, rhs, derivative: derivative)
  end
end

class MaxArray < BaseNumericOperation
  def process(lhs, rhs, derivative: false)
    execute_binary_operation(:MaxArray, lhs, rhs, derivative: derivative)
  end
end

class MinArray < BaseNumericOperation
  def process(lhs, rhs, derivative: false)
    execute_binary_operation(:MinArray, lhs, rhs, derivative: derivative)
  end
end
322
+
323
# ============================================================================
# OPERATIONS WITH A SCALAR
# ============================================================================
# Each class passes (values, scalar) to the native function of the same name
# through BaseNumericOperation#execute_scalar_operation.

class AddScalarToArray < BaseNumericOperation
  def process(values, offset, derivative: false)
    execute_scalar_operation(:AddScalarToArray, values, offset, derivative: derivative)
  end
end

class MultiplyScalarToArray < BaseNumericOperation
  def process(values, factor, derivative: false)
    execute_scalar_operation(:MultiplyScalarToArray, values, factor, derivative: derivative)
  end
end
338
+
339
# ============================================================================
# STATISTICAL OPERATIONS (RETURN A SCALAR)
# ============================================================================
# These names are listed among the unary operations whose one-element output
# is unwrapped, so #process returns a single Float.

class MeanArray < BaseNumericOperation
  def process(values, derivative: false)
    execute_unary_operation(:MeanArray, values, derivative: derivative)
  end
end

class VarianceArray < BaseNumericOperation
  def process(values, derivative: false)
    execute_unary_operation(:VarianceArray, values, derivative: derivative)
  end
end

class StdArray < BaseNumericOperation
  def process(values, derivative: false)
    execute_unary_operation(:StdArray, values, derivative: derivative)
  end
end

class MaxValueArray < BaseNumericOperation
  def process(values, derivative: false)
    execute_unary_operation(:MaxValueArray, values, derivative: derivative)
  end
end

class MinValueArray < BaseNumericOperation
  def process(values, derivative: false)
    execute_unary_operation(:MinValueArray, values, derivative: derivative)
  end
end

class SumAllArray < BaseNumericOperation
  def process(values, derivative: false)
    execute_unary_operation(:SumAllArray, values, derivative: derivative)
  end
end

class ProductAllArray < BaseNumericOperation
  def process(values, derivative: false)
    execute_unary_operation(:ProductAllArray, values, derivative: derivative)
  end
end
384
+
385
# ============================================================================
# ARRAY OPERATIONS
# ============================================================================

# Output length is the sum of both flattened input lengths.
class ConcatenateArrays < BaseNumericOperation
  def process(head, tail, derivative: false)
    execute_binary_operation(:ConcatenateArrays, head, tail, derivative: derivative)
  end
end

# Output length is max(n - window_size + 1, 0); window_size goes to the native size_t argument.
class MovingAverage < BaseNumericOperation
  def initialize(window_size: 3)
    @window_size = window_size
  end

  def process(series, derivative: false)
    execute_unary_operation(:MovingAverage, series, derivative: derivative, window_size: @window_size)
  end
end

# Output length is max(n - 1, 0).
class FiniteDifference < BaseNumericOperation
  def process(series, derivative: false)
    execute_unary_operation(:FiniteDifference, series, derivative: derivative)
  end
end

class DiscreteIntegral < BaseNumericOperation
  def process(series, derivative: false)
    execute_unary_operation(:DiscreteIntegral, series, derivative: derivative)
  end
end

class ZScoreNormalize < BaseNumericOperation
  def process(values, derivative: false)
    execute_unary_operation(:ZScoreNormalize, values, derivative: derivative)
  end
end

# min_range / max_range are passed as the two trailing native double arguments.
class MinMaxNormalize < BaseNumericOperation
  def initialize(min_range: 0.0, max_range: 1.0)
    @min_range = min_range
    @max_range = max_range
  end

  def process(values, derivative: false)
    execute_unary_operation(:MinMaxNormalize, values, derivative: derivative,
                                                      min_range: @min_range,
                                                      max_range: @max_range)
  end
end
434
+
435
# ============================================================================
# NUMERIC ACTIVATION FUNCTIONS
# ============================================================================
# Every class forwards to the native function of the same name through
# BaseNumericOperation#execute_unary_operation. Constructor keywords are stored
# and forwarded as :param1 / :param2 (native doubles; a pointer for PReLU) or
# :num_pieces (native size_t).

class LinearNum < BaseNumericOperation
  def process(x, derivative: false)
    execute_unary_operation(:LinearNum, x, derivative: derivative)
  end
end

class ReLUNum < BaseNumericOperation
  def process(x, derivative: false)
    execute_unary_operation(:ReLUNum, x, derivative: derivative)
  end
end

class SigmoidNum < BaseNumericOperation
  def process(x, derivative: false)
    execute_unary_operation(:SigmoidNum, x, derivative: derivative)
  end
end

class TanhNum < BaseNumericOperation
  def process(x, derivative: false)
    execute_unary_operation(:TanhNum, x, derivative: derivative)
  end
end

class LeakyReLUNum < BaseNumericOperation
  def initialize(alpha: 0.01)
    @alpha = alpha
  end

  def process(x, derivative: false)
    execute_unary_operation(:LeakyReLUNum, x, derivative: derivative, param1: @alpha)
  end
end

class ELUNum < BaseNumericOperation
  def initialize(alpha: 1.0)
    @alpha = alpha
  end

  def process(x, derivative: false)
    execute_unary_operation(:ELUNum, x, derivative: derivative, param1: @alpha)
  end
end

class SoftplusNum < BaseNumericOperation
  def process(x, derivative: false)
    execute_unary_operation(:SoftplusNum, x, derivative: derivative)
  end
end

class GELUNum < BaseNumericOperation
  def process(x, derivative: false)
    execute_unary_operation(:GELUNum, x, derivative: derivative)
  end
end

class SwishNum < BaseNumericOperation
  def process(x, derivative: false)
    execute_unary_operation(:SwishNum, x, derivative: derivative)
  end
end

class MishNum < BaseNumericOperation
  def process(x, derivative: false)
    execute_unary_operation(:MishNum, x, derivative: derivative)
  end
end

class SiLUNum < BaseNumericOperation
  def process(x, derivative: false)
    execute_unary_operation(:SiLUNum, x, derivative: derivative)
  end
end

class HardTanhNum < BaseNumericOperation
  def process(x, derivative: false)
    execute_unary_operation(:HardTanhNum, x, derivative: derivative)
  end
end

class HardSigmoidNum < BaseNumericOperation
  def process(x, derivative: false)
    execute_unary_operation(:HardSigmoidNum, x, derivative: derivative)
  end
end

class HardSwishNum < BaseNumericOperation
  def process(x, derivative: false)
    execute_unary_operation(:HardSwishNum, x, derivative: derivative)
  end
end

class SoftsignNum < BaseNumericOperation
  def process(x, derivative: false)
    execute_unary_operation(:SoftsignNum, x, derivative: derivative)
  end
end

class SELUNum < BaseNumericOperation
  def process(x, derivative: false)
    execute_unary_operation(:SELUNum, x, derivative: derivative)
  end
end

class CELUNum < BaseNumericOperation
  def initialize(alpha: 1.0)
    @alpha = alpha
  end

  def process(x, derivative: false)
    execute_unary_operation(:CELUNum, x, derivative: derivative, param1: @alpha)
  end
end

class ISRUNum < BaseNumericOperation
  def initialize(alpha: 1.0)
    @alpha = alpha
  end

  def process(x, derivative: false)
    execute_unary_operation(:ISRUNum, x, derivative: derivative, param1: @alpha)
  end
end

class ISRLUNum < BaseNumericOperation
  def initialize(alpha: 1.0)
    @alpha = alpha
  end

  def process(x, derivative: false)
    execute_unary_operation(:ISRLUNum, x, derivative: derivative, param1: @alpha)
  end
end

class ReLU6Num < BaseNumericOperation
  def process(x, derivative: false)
    execute_unary_operation(:ReLU6Num, x, derivative: derivative)
  end
end

class ThresholdedReLUNum < BaseNumericOperation
  def initialize(theta: 1.0)
    @theta = theta
  end

  def process(x, derivative: false)
    execute_unary_operation(:ThresholdedReLUNum, x, derivative: derivative, param1: @theta)
  end
end

class PReLUNum < BaseNumericOperation
  def initialize(alpha: 0.01)
    @alpha = alpha
  end

  def process(x, derivative: false)
    execute_unary_operation(:PReLUNum, x, derivative: derivative, param1: @alpha)
  end
end

class SquaredReLUNum < BaseNumericOperation
  def process(x, derivative: false)
    execute_unary_operation(:SquaredReLUNum, x, derivative: derivative)
  end
end

class LiSHTNum < BaseNumericOperation
  def process(x, derivative: false)
    execute_unary_operation(:LiSHTNum, x, derivative: derivative)
  end
end

class SnakeNum < BaseNumericOperation
  def initialize(alpha: 1.0)
    @alpha = alpha
  end

  def process(x, derivative: false)
    execute_unary_operation(:SnakeNum, x, derivative: derivative, param1: @alpha)
  end
end

class SnakeBetaNum < BaseNumericOperation
  def initialize(alpha: 1.0, beta: 1.0)
    @alpha = alpha
    @beta = beta
  end

  def process(x, derivative: false)
    execute_unary_operation(:SnakeBetaNum, x, derivative: derivative,
                                              param1: @alpha, param2: @beta)
  end
end

class AReluNum < BaseNumericOperation
  def initialize(alpha: 0.0, beta: 1.0)
    @alpha = alpha
    @beta = beta
  end

  def process(x, derivative: false)
    execute_unary_operation(:AReluNum, x, derivative: derivative,
                                          param1: @alpha, param2: @beta)
  end
end

class FReLUNum < BaseNumericOperation
  def initialize(alpha: 1.0)
    @alpha = alpha
  end

  def process(x, derivative: false)
    execute_unary_operation(:FReLUNum, x, derivative: derivative, param1: @alpha)
  end
end

class BReLUNum < BaseNumericOperation
  def process(x, derivative: false)
    execute_unary_operation(:BReLUNum, x, derivative: derivative)
  end
end

class HardShrinkNum < BaseNumericOperation
  def initialize(lambda: 0.5)
    @lambda = lambda
  end

  def process(x, derivative: false)
    execute_unary_operation(:HardShrinkNum, x, derivative: derivative, param1: @lambda)
  end
end

class SoftShrinkNum < BaseNumericOperation
  def initialize(lambda: 0.5)
    @lambda = lambda
  end

  def process(x, derivative: false)
    execute_unary_operation(:SoftShrinkNum, x, derivative: derivative, param1: @lambda)
  end
end

class TanhShrinkNum < BaseNumericOperation
  def process(x, derivative: false)
    execute_unary_operation(:TanhShrinkNum, x, derivative: derivative)
  end
end

class MaxoutNum < BaseNumericOperation
  def initialize(num_pieces: 2)
    @num_pieces = num_pieces
  end

  def process(x, derivative: false)
    execute_unary_operation(:MaxoutNum, x, derivative: derivative, num_pieces: @num_pieces)
  end
end

class MinoutNum < BaseNumericOperation
  def initialize(num_pieces: 2)
    @num_pieces = num_pieces
  end

  def process(x, derivative: false)
    execute_unary_operation(:MinoutNum, x, derivative: derivative, num_pieces: @num_pieces)
  end
end

# `dim` is forwarded through the :num_pieces key as the native size_t argument.
class GLUNum < BaseNumericOperation
  def initialize(dim: 1)
    @dim = dim
  end

  def process(x, derivative: false)
    execute_unary_operation(:GLUNum, x, derivative: derivative, num_pieces: @dim)
  end
end

class ReLUSquaredNum < BaseNumericOperation
  def process(x, derivative: false)
    execute_unary_operation(:ReLUSquaredNum, x, derivative: derivative)
  end
end
722
+ end