grnexus 1.0.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,319 @@
1
+ require 'ffi'
2
+ require 'rbconfig'
3
+
4
+ module GRNEXUSNormalization
5
+ extend FFI::Library
6
+
7
+ # Detectar sistema operativo y cargar biblioteca apropiada
8
+ def self.detect_library
9
+ script_dir = File.dirname(File.expand_path(__FILE__))
10
+ os = RbConfig::CONFIG['host_os']
11
+
12
+ case os
13
+ when /mswin|mingw|cygwin/
14
+ File.join(script_dir, '..', 'exports', 'Windows', 'normalization.dll')
15
+ when /darwin/
16
+ File.join(script_dir, '..', 'exports', 'Mac', 'normalization.dylib')
17
+ when /linux/
18
+ File.join(script_dir, '..', 'exports', 'Linux', 'normalization.so')
19
+ else
20
+ raise "Sistema operativo no soportado: #{os}"
21
+ end
22
+ end
23
+
24
+ # Cargar biblioteca compartida
25
+ ffi_lib detect_library
26
+
27
+ # Definir estructura GRNexusData
28
+ class GRNexusData < FFI::Struct
29
+ layout :data, :pointer,
30
+ :type, :int,
31
+ :size, :size_t,
32
+ :stride, :size_t,
33
+ :dims, [:size_t, 3]
34
+ end
35
+
36
+ # Definir firmas de funciones
37
+ FUNCTIONS = {
38
+ 'Softmax' => [:Softmax, [:pointer, :pointer, :bool, :size_t], :int],
39
+ 'LogSoftmax' => [:LogSoftmax, [:pointer, :pointer, :bool, :size_t], :int],
40
+ 'Sparsemax' => [:Sparsemax, [:pointer, :pointer, :bool, :size_t], :int],
41
+ 'TsallisSoftmax' => [:TsallisSoftmax, [:pointer, :pointer, :bool, :size_t, :double], :int],
42
+ }
43
+
44
+ # Adjuntar funciones a la biblioteca
45
+ FUNCTIONS.each do |name, (func, args, ret)|
46
+ attach_function func, args, ret
47
+ end
48
+
49
+ # Función auxiliar para aplanar arrays y determinar dimensiones
50
+ def self.flatten_array(lst)
51
+ # Si no es un array, retornar como array de un elemento
52
+ return [lst], [] unless lst.is_a?(Array)
53
+
54
+ # Verificar si es un array 1D
55
+ if lst.none? { |item| item.is_a?(Array) }
56
+ return lst, [lst.length]
57
+ end
58
+
59
+ # Verificar si es un array 2D
60
+ if lst.all? { |item| item.is_a?(Array) && item.none? { |x| x.is_a?(Array) } }
61
+ flat = []
62
+ lst.each { |row| flat.concat(row) }
63
+ return flat, [lst.length, lst[0]&.length || 0]
64
+ end
65
+
66
+ # Verificar si es un array 3D
67
+ if lst.all? { |item| item.is_a?(Array) }
68
+ if lst.all? { |item| item.all? { |row| row.is_a?(Array) } }
69
+ flat = []
70
+ dim0 = lst.length
71
+ dim1 = lst[0]&.length || 0
72
+ dim2 = lst[0]&.[](0)&.length || 0
73
+ lst.each do |matrix|
74
+ matrix.each { |row| flat.concat(row) }
75
+ end
76
+ return flat, [dim0, dim1, dim2]
77
+ end
78
+ end
79
+
80
+ # Fallback: aplanar recursivamente
81
+ flat = []
82
+ recursive_flatten = lambda do |l|
83
+ l.each do |item|
84
+ if item.is_a?(Array)
85
+ recursive_flatten.call(item)
86
+ else
87
+ flat << item
88
+ end
89
+ end
90
+ end
91
+ recursive_flatten.call(lst)
92
+ [flat, [flat.length]]
93
+ end
94
+
95
+ # Función para crear GRNexusData
96
+ def self.create_grnexus_data(array_or_scalar)
97
+ values, dims = flatten_array(array_or_scalar)
98
+ size = values.length
99
+
100
+ # Crear buffer con los datos
101
+ buffer = FFI::MemoryPointer.new(:double, size)
102
+ buffer.write_array_of_double(values)
103
+
104
+ data = FFI::MemoryPointer.new(GRNexusData.size)
105
+ struct_instance = GRNexusData.new(data)
106
+ struct_instance[:data] = buffer
107
+ struct_instance[:type] = size == 1 ? 0 : 1 # SCALAR=0, ARRAY=1
108
+ struct_instance[:size] = size
109
+ struct_instance[:stride] = 1
110
+
111
+ # Establecer dimensiones
112
+ if dims.empty?
113
+ struct_instance[:dims][0] = 1
114
+ struct_instance[:dims][1] = 0
115
+ struct_instance[:dims][2] = 0
116
+ elsif dims.length == 1
117
+ struct_instance[:dims][0] = dims[0]
118
+ struct_instance[:dims][1] = 0
119
+ struct_instance[:dims][2] = 0
120
+ elsif dims.length == 2
121
+ struct_instance[:dims][0] = dims[0]
122
+ struct_instance[:dims][1] = dims[1]
123
+ struct_instance[:dims][2] = 0
124
+ elsif dims.length >= 3
125
+ struct_instance[:dims][0] = dims[0]
126
+ struct_instance[:dims][1] = dims[1]
127
+ struct_instance[:dims][2] = dims[2]
128
+ end
129
+
130
+ [data, buffer, dims]
131
+ end
132
+
133
+ # Función para crear GRNexusData de salida
134
+ def self.create_output_grnexus_data(size)
135
+ buffer = FFI::MemoryPointer.new(:double, size)
136
+
137
+ data = FFI::MemoryPointer.new(GRNexusData.size)
138
+ struct_instance = GRNexusData.new(data)
139
+ struct_instance[:data] = buffer
140
+ struct_instance[:type] = 1 # ARRAY
141
+ struct_instance[:size] = size
142
+ struct_instance[:stride] = 1
143
+ struct_instance[:dims][0] = size
144
+ struct_instance[:dims][1] = 0
145
+ struct_instance[:dims][2] = 0
146
+
147
+ [data, buffer]
148
+ end
149
+
150
+ # Función para leer datos de GRNexusData
151
+ def self.read_grnexus_data(buffer, size)
152
+ buffer.read_array_of_double(size)
153
+ end
154
+
155
+ # Clase base de normalización
156
+ class NormalizationLayer
157
+ def normalize(data, derivative: false)
158
+ raise NotImplementedError, "Debes implementar el método de normalización"
159
+ end
160
+
161
+ def call(data, derivative: false)
162
+ normalize(data, derivative: derivative)
163
+ end
164
+ end
165
+
166
+ class BaseNormalization < NormalizationLayer
167
+ protected
168
+
169
+ def execute_normalization(func_name, input_values, derivative: false, **kwargs)
170
+ # Guardar forma original para remodelar la salida
171
+ original_shape = nil
172
+ if input_values.is_a?(Array)
173
+ # Detectar forma
174
+ if input_values[0].is_a?(Array)
175
+ if input_values[0][0].is_a?(Array)
176
+ # Array 3D
177
+ original_shape = [input_values.length, input_values[0].length, input_values[0][0].length]
178
+ else
179
+ # Array 2D
180
+ original_shape = [input_values.length, input_values[0].length]
181
+ end
182
+ end
183
+ end
184
+
185
+ input_data, input_buffer, dims = GRNEXUSNormalization.create_grnexus_data(input_values)
186
+ input_struct = GRNEXUSNormalization::GRNexusData.new(input_data)
187
+ output_size = input_struct[:size]
188
+ output_data, output_buffer = GRNEXUSNormalization.create_output_grnexus_data(output_size)
189
+
190
+ args = [input_data, output_data, derivative]
191
+
192
+ # Agregar parámetro axis
193
+ axis = kwargs[:axis] || output_size # Usar size como axis "global" por defecto
194
+
195
+ # Manejar axis negativo (indexación estilo Python)
196
+ if axis < 0
197
+ if original_shape
198
+ axis = original_shape.length + axis
199
+ else
200
+ axis = 0
201
+ end
202
+ end
203
+
204
+ args << axis
205
+
206
+ # Agregar argumentos específicos de la función
207
+ if func_name == :TsallisSoftmax
208
+ args << (kwargs[:q] || 1.0)
209
+ end
210
+
211
+ # Llamar a la función C
212
+ result = GRNEXUSNormalization.send(func_name, *args)
213
+
214
+ if result != 0
215
+ raise "Función #{func_name} falló con código: #{result}"
216
+ end
217
+
218
+ # Obtener resultado plano
219
+ flat_result = output_buffer.read_array_of_double(output_size)
220
+
221
+ # Remodelar si es necesario
222
+ if original_shape
223
+ if original_shape.length == 2
224
+ # Remodelar a 2D
225
+ reshaped = []
226
+ idx = 0
227
+ original_shape[0].times do
228
+ row = []
229
+ original_shape[1].times do
230
+ row << flat_result[idx]
231
+ idx += 1
232
+ end
233
+ reshaped << row
234
+ end
235
+ return reshaped
236
+ elsif original_shape.length == 3
237
+ # Remodelar a 3D
238
+ reshaped = []
239
+ idx = 0
240
+ original_shape[0].times do
241
+ matrix = []
242
+ original_shape[1].times do
243
+ row = []
244
+ original_shape[2].times do
245
+ row << flat_result[idx]
246
+ idx += 1
247
+ end
248
+ matrix << row
249
+ end
250
+ reshaped << matrix
251
+ end
252
+ return reshaped
253
+ end
254
+ end
255
+
256
+ flat_result
257
+ ensure
258
+ # FFI::MemoryPointer se libera automáticamente con GC
259
+ end
260
+ end
261
+
262
+ # Clases de capas de normalización
263
+ class Softmax < BaseNormalization
264
+ def initialize(axis: -1)
265
+ @axis = axis
266
+ end
267
+
268
+ def normalize(data, derivative: false)
269
+ execute_normalization(:Softmax, data, derivative: derivative, axis: @axis)
270
+ end
271
+
272
+ def call(data, derivative: false)
273
+ normalize(data, derivative: derivative)
274
+ end
275
+ end
276
+
277
+ class LogSoftmax < BaseNormalization
278
+ def initialize(axis: -1)
279
+ @axis = axis
280
+ end
281
+
282
+ def normalize(data, derivative: false)
283
+ execute_normalization(:LogSoftmax, data, derivative: derivative, axis: @axis)
284
+ end
285
+
286
+ def call(data, derivative: false)
287
+ normalize(data, derivative: derivative)
288
+ end
289
+ end
290
+
291
+ class Sparsemax < BaseNormalization
292
+ def initialize(axis: -1)
293
+ @axis = axis
294
+ end
295
+
296
+ def normalize(data, derivative: false)
297
+ execute_normalization(:Sparsemax, data, derivative: derivative, axis: @axis)
298
+ end
299
+
300
+ def call(data, derivative: false)
301
+ normalize(data, derivative: derivative)
302
+ end
303
+ end
304
+
305
+ class TsallisSoftmax < BaseNormalization
306
+ def initialize(axis: -1, q: 1.0)
307
+ @axis = axis
308
+ @q = q
309
+ end
310
+
311
+ def normalize(data, derivative: false)
312
+ execute_normalization(:TsallisSoftmax, data, derivative: derivative, axis: @axis, q: @q)
313
+ end
314
+
315
+ def call(data, derivative: false)
316
+ normalize(data, derivative: derivative)
317
+ end
318
+ end
319
+ end