tensor_stream 0.2.0 → 0.3.0

Sign up to get free protection for your applications and to get access to all the features.
Files changed (58) hide show
  1. checksums.yaml +5 -5
  2. data/.circleci/config.yml +2 -1
  3. data/CHANGELOG.md +5 -0
  4. data/README.md +28 -1
  5. data/benchmark/benchmark.rb +129 -0
  6. data/lib/tensor_stream.rb +7 -4
  7. data/lib/tensor_stream/evaluator/buffer.rb +10 -0
  8. data/lib/tensor_stream/evaluator/evaluator.rb +1 -0
  9. data/lib/tensor_stream/evaluator/kernels/_bool_operand.cl +45 -0
  10. data/lib/tensor_stream/evaluator/kernels/_operand.cl +45 -0
  11. data/lib/tensor_stream/evaluator/kernels/abs.cl +16 -0
  12. data/lib/tensor_stream/evaluator/kernels/add.cl +5 -0
  13. data/lib/tensor_stream/evaluator/kernels/argmax.cl +15 -0
  14. data/lib/tensor_stream/evaluator/kernels/argmin.cl +15 -0
  15. data/lib/tensor_stream/evaluator/kernels/cast.cl +15 -0
  16. data/lib/tensor_stream/evaluator/kernels/cond.cl.erb +5 -0
  17. data/lib/tensor_stream/evaluator/kernels/cos.cl +7 -0
  18. data/lib/tensor_stream/evaluator/kernels/div.cl.erb +5 -0
  19. data/lib/tensor_stream/evaluator/kernels/exp.cl +7 -0
  20. data/lib/tensor_stream/evaluator/kernels/gemm.cl +63 -0
  21. data/lib/tensor_stream/evaluator/kernels/log.cl +7 -0
  22. data/lib/tensor_stream/evaluator/kernels/log1p.cl +7 -0
  23. data/lib/tensor_stream/evaluator/kernels/max.cl +91 -0
  24. data/lib/tensor_stream/evaluator/kernels/mul.cl +5 -0
  25. data/lib/tensor_stream/evaluator/kernels/negate.cl +15 -0
  26. data/lib/tensor_stream/evaluator/kernels/pow.cl +130 -0
  27. data/lib/tensor_stream/evaluator/kernels/reciprocal.cl +15 -0
  28. data/lib/tensor_stream/evaluator/kernels/round.cl +7 -0
  29. data/lib/tensor_stream/evaluator/kernels/sigmoid.cl +8 -0
  30. data/lib/tensor_stream/evaluator/kernels/sigmoid_grad.cl +54 -0
  31. data/lib/tensor_stream/evaluator/kernels/sign.cl +23 -0
  32. data/lib/tensor_stream/evaluator/kernels/sin.cl +8 -0
  33. data/lib/tensor_stream/evaluator/kernels/sqrt.cl +8 -0
  34. data/lib/tensor_stream/evaluator/kernels/square.cl +15 -0
  35. data/lib/tensor_stream/evaluator/kernels/sub.cl +5 -0
  36. data/lib/tensor_stream/evaluator/kernels/tan.cl +7 -0
  37. data/lib/tensor_stream/evaluator/kernels/tanh.cl +7 -0
  38. data/lib/tensor_stream/evaluator/kernels/tanh_grad.cl +6 -0
  39. data/lib/tensor_stream/evaluator/kernels/where.cl +15 -0
  40. data/lib/tensor_stream/evaluator/opencl_buffer.rb +30 -0
  41. data/lib/tensor_stream/evaluator/opencl_evaluator.rb +1095 -0
  42. data/lib/tensor_stream/evaluator/opencl_template_helper.rb +58 -0
  43. data/lib/tensor_stream/evaluator/operation_helpers/array_ops_helper.rb +27 -0
  44. data/lib/tensor_stream/evaluator/ruby_evaluator.rb +20 -31
  45. data/lib/tensor_stream/graph.rb +4 -2
  46. data/lib/tensor_stream/math_gradients.rb +3 -0
  47. data/lib/tensor_stream/operation.rb +29 -2
  48. data/lib/tensor_stream/ops.rb +14 -2
  49. data/lib/tensor_stream/placeholder.rb +1 -1
  50. data/lib/tensor_stream/session.rb +10 -3
  51. data/lib/tensor_stream/tensor_shape.rb +1 -1
  52. data/lib/tensor_stream/train/saver.rb +1 -1
  53. data/lib/tensor_stream/variable.rb +7 -1
  54. data/lib/tensor_stream/version.rb +1 -1
  55. data/samples/logistic_regression.rb +2 -1
  56. data/samples/nearest_neighbor.rb +54 -0
  57. data/tensor_stream.gemspec +3 -1
  58. metadata +107 -28
@@ -0,0 +1,7 @@
1
+ __kernel void cos_fp(const int M, const int N, __global const float *A, __global float *C) {
2
+ // Get the index of the current element to be processed
3
+ const int globalRow = get_global_id(0); // Row ID of C (0..M)
4
+ const int globalCol = get_global_id(1); // Col ID of C (0..N)
5
+
6
+ C[globalRow * N + globalCol] = cos(A[globalRow * N + globalCol]);
7
+ }
@@ -0,0 +1,5 @@
1
+ % %w[fp int].product(%w[div]).each do |dtype, fname|
2
+ % c_dtype = dtype_to_c_type(dtype)
3
+ % op = operator_to_c(fname)
4
+ <%= render 'operand.cl', c_dtype: c_dtype, op: op, fname: fname, dtype: dtype, result_t: c_dtype %>
5
+ % end
@@ -0,0 +1,7 @@
1
+ __kernel void exp_fp(const int M, const int N, __global const float *A, __global float *C) {
2
+ // Get the index of the current element to be processed
3
+ const int globalRow = get_global_id(0); // Row ID of C (0..M)
4
+ const int globalCol = get_global_id(1); // Col ID of C (0..N)
5
+
6
+ C[globalRow * N + globalCol] = exp(A[globalRow * N + globalCol]);
7
+ }
@@ -0,0 +1,63 @@
1
+ // First naive implementation
2
+ __kernel void gemm_fp(const int M, const int N, const int K,
3
+ const int A_transpose,
4
+ const int B_transpose,
5
+ const __global float* A,
6
+ const __global float* B,
7
+ __global float* C) {
8
+
9
+ // Get the index of the current element to be processed
10
+ const int globalRow = get_global_id(0); // Row ID of C (0..M)
11
+ const int globalCol = get_global_id(1); // Col ID of C (0..N)
12
+
13
+ // Compute a single element (loop over K)
14
+ float acc = 0.0f;
15
+ for (int k=0; k<K; k++) {
16
+ int a_index = globalRow*K + k;
17
+ int b_index = k*N + globalCol;
18
+
19
+ if (A_transpose) {
20
+ a_index = M*k + globalRow;
21
+ }
22
+
23
+ if (B_transpose) {
24
+ b_index = globalCol*K + k;
25
+ }
26
+ acc += A[a_index] * B[b_index];
27
+ }
28
+
29
+ // Store the result
30
+ C[globalRow*N + globalCol] = acc;
31
+ }
32
+
33
+ // First naive implementation
34
+ __kernel void gemm_int(const int M, const int N, const int K,
35
+ const int A_transpose,
36
+ const int B_transpose,
37
+ const __global int* A,
38
+ const __global int* B,
39
+ __global int* C) {
40
+
41
+ // Get the index of the current element to be processed
42
+ const int globalRow = get_global_id(0); // Row ID of C (0..M)
43
+ const int globalCol = get_global_id(1); // Col ID of C (0..N)
44
+
45
+ // Compute a single element (loop over K)
46
+ int acc = 0;
47
+ for (int k=0; k<K; k++) {
48
+ int a_index = globalRow*K + k;
49
+ int b_index = k*N + globalCol;
50
+
51
+ if (A_transpose) {
52
+ a_index = M*k + globalRow;
53
+ }
54
+
55
+ if (B_transpose) {
56
+ b_index = globalCol*K + k;
57
+ }
58
+ acc += A[a_index] * B[b_index];
59
+ }
60
+
61
+ // Store the result
62
+ C[globalRow*N + globalCol] = acc;
63
+ }
@@ -0,0 +1,7 @@
1
+ __kernel void log_fp(const int M, const int N, __global const float *A, __global float *C) {
2
+ // Get the index of the current element to be processed
3
+ const int globalRow = get_global_id(0); // Row ID of C (0..M)
4
+ const int globalCol = get_global_id(1); // Col ID of C (0..N)
5
+
6
+ C[globalRow * N + globalCol] = log(A[globalRow * N + globalCol]);
7
+ }
@@ -0,0 +1,7 @@
1
+ __kernel void log1p_fp(const int M, const int N, __global const float *A, __global float *C) {
2
+ // Get the index of the current element to be processed
3
+ const int globalRow = get_global_id(0); // Row ID of C (0..M)
4
+ const int globalCol = get_global_id(1); // Col ID of C (0..N)
5
+
6
+ C[globalRow * N + globalCol] = log1p(A[globalRow * N + globalCol]);
7
+ }
@@ -0,0 +1,91 @@
1
+ // same dimension floating point max op
2
+ __kernel void max_fp(const int M, const int N, const int switch_op, __global const float *A, __global const float *B, __global float *C) {
3
+ // Get the index of the current element to be processed
4
+ const int globalRow = get_global_id(0); // Row ID of C (0..M)
5
+ const int globalCol = get_global_id(1); // Col ID of C (0..N)
6
+
7
+ C[globalRow * N + globalCol] = A[globalRow * N + globalCol] > B[globalRow * N + globalCol] ? A[globalRow * N + globalCol] : B[globalRow * N + globalCol];
8
+ }
9
+
10
+ // 1D + Scalar floating point max op
11
+ __kernel void max_c_fp(const int M, const int N, const int switch_op, __global const float *A, __global const float *B, __global float *C) {
12
+ // Get the index of the current element to be processed
13
+ const int globalRow = get_global_id(0); // Row ID of C (0..M)
14
+ const int globalCol = get_global_id(1); // Col ID of C (0..N)
15
+
16
+ if (switch_op == 0) {
17
+ C[globalRow * N + globalCol] = A[globalRow * N + globalCol] > B[0] ? A[globalRow * N + globalCol] : B[0];
18
+ } else {
19
+ C[globalRow * N + globalCol] = B[0] > A[globalRow * N + globalCol] ? B[0] : A[globalRow * N + globalCol];
20
+ }
21
+ }
22
+
23
+ // floating point max op with broadcast (B is indexed modulo M2 x N2)
24
+ __kernel void max_b_fp(const int M, const int N, const int M2, const int N2, const int switch_op, __global const float *A, __global const float *B, __global float *C) {
25
+ // Get the index of the current element to be processed
26
+ const int globalRow = get_global_id(0); // Row ID of C (0..M)
27
+ const int globalCol = get_global_id(1); // Col ID of C (0..N)
28
+
29
+ int b_m_index = globalRow;
30
+ int b_n_index = globalCol;
31
+
32
+ if ( b_m_index >= M2) {
33
+ b_m_index = b_m_index % M2;
34
+ };
35
+
36
+ if (b_n_index >= N2) {
37
+ b_n_index = b_n_index % N2;
38
+ }
39
+
40
+ if (switch_op == 0) {
41
+ C[globalRow * N + globalCol] = A[globalRow * N + globalCol] > B[b_m_index * N2 + b_n_index] ? A[globalRow * N + globalCol] : B[b_m_index * N2 + b_n_index];
42
+ } else {
43
+ C[globalRow * N + globalCol] = B[b_m_index * N2 + b_n_index] > A[globalRow * N + globalCol] ? B[b_m_index * N2 + b_n_index] : A[globalRow * N + globalCol];
44
+ }
45
+ }
46
+
47
+ // same dimension integer max op
48
+ __kernel void max_int(const int M, const int N, const int switch_op, __global const int *A, __global const int *B, __global int *C) {
49
+ // Get the index of the current element to be processed
50
+ const int globalRow = get_global_id(0); // Row ID of C (0..M)
51
+ const int globalCol = get_global_id(1); // Col ID of C (0..N)
52
+
53
+ C[globalRow * N + globalCol] = A[globalRow * N + globalCol] > B[globalRow * N + globalCol] ? A[globalRow * N + globalCol] : B[globalRow * N + globalCol];
54
+ }
55
+
56
+ // 1D + Scalar integer max op
57
+ __kernel void max_c_int(const int M, const int N, const int switch_op, __global const int *A, __global const int *B, __global int *C) {
58
+ // Get the index of the current element to be processed
59
+ const int globalRow = get_global_id(0); // Row ID of C (0..M)
60
+ const int globalCol = get_global_id(1); // Col ID of C (0..N)
61
+
62
+ if (switch_op == 0) {
63
+ C[globalRow * N + globalCol] = A[globalRow * N + globalCol] > B[0] ? A[globalRow * N + globalCol] : B[0];
64
+ } else {
65
+ C[globalRow * N + globalCol] = B[0] > A[globalRow * N + globalCol] ? B[0] : A[globalRow * N + globalCol];
66
+ }
67
+ }
68
+
69
+ // integer max op with broadcast (B is indexed modulo M2 x N2)
70
+ __kernel void max_b_int(const int M, const int N, const int M2, const int N2, const int switch_op, __global const int *A, __global const int *B, __global int *C) {
71
+ // Get the index of the current element to be processed
72
+ const int globalRow = get_global_id(0); // Row ID of C (0..M)
73
+ const int globalCol = get_global_id(1); // Col ID of C (0..N)
74
+
75
+ int b_m_index = globalRow;
76
+ int b_n_index = globalCol;
77
+
78
+ if ( b_m_index >= M2) {
79
+ b_m_index = b_m_index % M2;
80
+ };
81
+
82
+ if (b_n_index >= N2) {
83
+ b_n_index = b_n_index % N2;
84
+ }
85
+
86
+ if (switch_op == 0) {
87
+ C[globalRow * N + globalCol] = A[globalRow * N + globalCol] > B[b_m_index * N2 + b_n_index] ? A[globalRow * N + globalCol] : B[b_m_index * N2 + b_n_index];
88
+ } else {
89
+ C[globalRow * N + globalCol] = B[b_m_index * N2 + b_n_index] > A[globalRow * N + globalCol] ? B[b_m_index * N2 + b_n_index] : A[globalRow * N + globalCol];
90
+ }
91
+ }
@@ -0,0 +1,5 @@
1
+ % %w[fp int].product(%w[mul]).each do |dtype, fname|
2
+ % c_dtype = dtype_to_c_type(dtype)
3
+ % op = operator_to_c(fname)
4
+ <%= render 'operand.cl', c_dtype: c_dtype, op: op, fname: fname, dtype: dtype, result_t: c_dtype %>
5
+ % end
@@ -0,0 +1,15 @@
1
+ __kernel void negate_fp(const int M, const int N, __global const float *A, __global float *C) {
2
+ // Get the index of the current element to be processed
3
+ const int globalRow = get_global_id(0); // Row ID of C (0..M)
4
+ const int globalCol = get_global_id(1); // Col ID of C (0..N)
5
+
6
+ C[globalRow * N + globalCol] = -A[globalRow * N + globalCol];
7
+ }
8
+
9
+ __kernel void negate_int(const int M, const int N, __global const int *A, __global int *C) {
10
+ // Get the index of the current element to be processed
11
+ const int globalRow = get_global_id(0); // Row ID of C (0..M)
12
+ const int globalCol = get_global_id(1); // Col ID of C (0..N)
13
+
14
+ C[globalRow * N + globalCol] = -A[globalRow * N + globalCol];
15
+ }
@@ -0,0 +1,130 @@
1
+ // same dimension floating point pow op
2
+ __kernel void pow_fp(const int M, const int N, const int switch_op, __global const float *A, __global const float *B, __global float *C) {
3
+ // Get the index of the current element to be processed
4
+ const int globalRow = get_global_id(0); // Row ID of C (0..M)
5
+ const int globalCol = get_global_id(1); // Col ID of C (0..N)
6
+
7
+ C[globalRow * N + globalCol] = pow((float)A[globalRow * N + globalCol], (float)B[globalRow * N + globalCol]);
8
+ }
9
+
10
+ // 1D + Scalar floating point pow op (switch_op swaps base and exponent)
11
+ __kernel void pow_c_fp(const int M, const int N, const int switch_op, __global const float *A, __global const float *B, __global float *C) {
12
+ // Get the index of the current element to be processed
13
+ const int globalRow = get_global_id(0); // Row ID of C (0..M)
14
+ const int globalCol = get_global_id(1); // Col ID of C (0..N)
15
+
16
+ if (switch_op == 0) {
17
+ C[globalRow * N + globalCol] = pow((float)A[globalRow * N + globalCol], (float)B[0]);
18
+ } else {
19
+ C[globalRow * N + globalCol] = pow((float)B[0], (float)A[globalRow * N + globalCol]);
20
+ }
21
+ }
22
+
23
+ // floating point pow op with broadcast (B is indexed modulo M2 x N2)
24
+ __kernel void pow_b_fp(const int M, const int N, const int M2, const int N2, const int switch_op, __global const float *A, __global const float *B, __global float *C) {
25
+ // Get the index of the current element to be processed
26
+ const int globalRow = get_global_id(0); // Row ID of C (0..M)
27
+ const int globalCol = get_global_id(1); // Col ID of C (0..N)
28
+
29
+ int b_m_index = globalRow;
30
+ int b_n_index = globalCol;
31
+
32
+ if ( b_m_index >= M2) {
33
+ b_m_index = b_m_index % M2;
34
+ };
35
+
36
+ if (b_n_index >= N2) {
37
+ b_n_index = b_n_index % N2;
38
+ }
39
+
40
+ if (switch_op == 0) {
41
+ C[globalRow * N + globalCol] = pow((float)A[globalRow * N + globalCol], (float)B[b_m_index * N2 + b_n_index]);
42
+ } else {
43
+ C[globalRow * N + globalCol] = pow((float)B[b_m_index * N2 + b_n_index], (float)A[globalRow * N + globalCol]);
44
+ }
45
+ }
46
+
47
+ // same dimension integer pow op
48
+ __kernel void pow_int(const int M, const int N, const int switch_op, __global const int *A, __global const int *B, __global int *C) {
49
+ // Get the index of the current element to be processed
50
+ const int globalRow = get_global_id(0); // Row ID of C (0..M)
51
+ const int globalCol = get_global_id(1); // Col ID of C (0..N)
52
+
53
+ int acc = A[globalRow * N + globalCol];
54
+ const int count = B[globalRow * N + globalCol];
55
+ const int c = A[globalRow * N + globalCol];
56
+
57
+ if (count < 4) {
58
+ for(int i = 0; i < count - 1; i++) {
59
+ acc *= c;
60
+ }
61
+ C[globalRow * N + globalCol] = acc;
62
+ } else {
63
+ C[globalRow * N + globalCol] = pow((float)c, (float)count);
64
+ }
65
+ }
66
+
67
+ // 1D + Scalar integer pow op (switch_op swaps base and exponent)
68
+ __kernel void pow_c_int(const int M, const int N, const int switch_op, __global const int *A, __global const int *B, __global int *C) {
69
+ // Get the index of the current element to be processed
70
+ const int globalRow = get_global_id(0); // Row ID of C (0..M)
71
+ const int globalCol = get_global_id(1); // Col ID of C (0..N)
72
+
73
+ int acc, count, c;
74
+ if (switch_op == 0) {
75
+ acc = A[globalRow * N + globalCol];
76
+ count = B[0];
77
+ c = A[globalRow * N + globalCol];
78
+ } else {
79
+ acc = B[0];
80
+ count = A[globalRow * N + globalCol];
81
+ c = B[0];
82
+ }
83
+ if (count < 4) {
84
+ for(int i =0; i < count - 1; i++) {
85
+ acc *= c;
86
+ }
87
+ C[globalRow * N + globalCol] = acc;
88
+ } else {
89
+ C[globalRow * N + globalCol] = pow((float)c, (float)count);
90
+ }
91
+ }
92
+
93
+ // integer pow op with broadcast (B is indexed modulo M2 x N2)
94
+ __kernel void pow_b_int(const int M, const int N, const int M2, const int N2, const int switch_op, __global const int *A, __global const int *B, __global int *C) {
95
+ // Get the index of the current element to be processed
96
+ const int globalRow = get_global_id(0); // Row ID of C (0..M)
97
+ const int globalCol = get_global_id(1); // Col ID of C (0..N)
98
+
99
+ int b_m_index = globalRow;
100
+ int b_n_index = globalCol;
101
+
102
+ if ( b_m_index >= M2) {
103
+ b_m_index = b_m_index % M2;
104
+ };
105
+
106
+ if (b_n_index >= N2) {
107
+ b_n_index = b_n_index % N2;
108
+ }
109
+
110
+ int acc, count, c;
111
+
112
+ if (switch_op == 0) {
113
+ acc = A[globalRow * N + globalCol];
114
+ count = B[b_m_index * N2 + b_n_index];
115
+ c = A[globalRow * N + globalCol];
116
+ } else {
117
+ acc = B[b_m_index * N2 + b_n_index];
118
+ count = A[globalRow * N + globalCol];
119
+ c = B[b_m_index * N2 + b_n_index];
120
+ }
121
+
122
+ if (count < 4) {
123
+ for (int i = 0; i < count - 1; i++) {
124
+ acc *= c;
125
+ }
126
+ C[globalRow * N + globalCol] = acc;
127
+ } else {
128
+ C[globalRow * N + globalCol] = pow((float)c, (float)count);
129
+ }
130
+ }
@@ -0,0 +1,15 @@
1
+ __kernel void reciprocal_fp(const int M, const int N, __global const float *A, __global float *C) {
2
+ // Get the index of the current element to be processed
3
+ const int globalRow = get_global_id(0); // Row ID of C (0..M)
4
+ const int globalCol = get_global_id(1); // Col ID of C (0..N)
5
+
6
+ C[globalRow * N + globalCol] = 1.0f / A[globalRow * N + globalCol];
7
+ }
8
+
9
+ __kernel void reciprocal_int(const int M, const int N, __global const int *A, __global int *C) {
10
+ // Get the index of the current element to be processed
11
+ const int globalRow = get_global_id(0); // Row ID of C (0..M)
12
+ const int globalCol = get_global_id(1); // Col ID of C (0..N)
13
+
14
+ C[globalRow * N + globalCol] = 1 / A[globalRow * N + globalCol];
15
+ }
@@ -0,0 +1,7 @@
1
+ __kernel void round_fp(const int M, const int N, __global const float *A, __global float *C) {
2
+ // Get the index of the current element to be processed
3
+ const int globalRow = get_global_id(0); // Row ID of C (0..M)
4
+ const int globalCol = get_global_id(1); // Col ID of C (0..N)
5
+
6
+ C[globalRow * N + globalCol] = round(A[globalRow * N + globalCol]);
7
+ }
@@ -0,0 +1,8 @@
1
+
2
+ __kernel void sigmoid_fp(const int M, const int N, __global const float *A, __global float *C) {
3
+ // Get the index of the current element to be processed
4
+ const int globalRow = get_global_id(0); // Row ID of C (0..M)
5
+ const int globalCol = get_global_id(1); // Col ID of C (0..N)
6
+
7
+ C[globalRow * N + globalCol] = 1.0f/(1.0f + exp(-A[globalRow * N + globalCol]));
8
+ }
@@ -0,0 +1,54 @@
1
+
2
+ float sigmoid(float x) {
3
+ return 1.0f/(1.0f + exp(-x));
4
+ }
5
+
6
+ float sigmoid_grad(float x, float g) {
7
+ return g * sigmoid(x) * ( 1.0f - sigmoid(x));
8
+ }
9
+
10
+ // same dimension floating point sigmoid gradient op
11
+ __kernel void sigmoid_grad_fp(const int M, const int N, const int switch_op, __global const float *A, __global const float *B, __global float *C) {
12
+ // Get the index of the current element to be processed
13
+ const int globalRow = get_global_id(0); // Row ID of C (0..M)
14
+ const int globalCol = get_global_id(1); // Col ID of C (0..N)
15
+
16
+ C[globalRow * N + globalCol] = sigmoid_grad(A[globalRow * N + globalCol], B[globalRow * N + globalCol]);
17
+ }
18
+
19
+ // 1D + Scalar floating point sigmoid gradient op
20
+ __kernel void sigmoid_grad_c_fp(const int M, const int N, const int switch_op, __global const float *A, __global const float *B, __global float *C) {
21
+ // Get the index of the current element to be processed
22
+ const int globalRow = get_global_id(0); // Row ID of C (0..M)
23
+ const int globalCol = get_global_id(1); // Col ID of C (0..N)
24
+
25
+ if (switch_op == 0) {
26
+ C[globalRow * N + globalCol] = sigmoid_grad(A[globalRow * N + globalCol], B[0]);
27
+ } else {
28
+ C[globalRow * N + globalCol] = sigmoid_grad(B[0], A[globalRow * N + globalCol]);
29
+ }
30
+ }
31
+
32
+ // floating point sigmoid gradient op with broadcast (B is indexed modulo M2 x N2)
33
+ __kernel void sigmoid_grad_b_fp(const int M, const int N, const int M2, const int N2, const int switch_op, __global const float *A, __global const float *B, __global float *C) {
34
+ // Get the index of the current element to be processed
35
+ const int globalRow = get_global_id(0); // Row ID of C (0..M)
36
+ const int globalCol = get_global_id(1); // Col ID of C (0..N)
37
+
38
+ int b_m_index = globalRow;
39
+ int b_n_index = globalCol;
40
+
41
+ if ( b_m_index >= M2) {
42
+ b_m_index = b_m_index % M2;
43
+ };
44
+
45
+ if (b_n_index >= N2) {
46
+ b_n_index = b_n_index % N2;
47
+ }
48
+
49
+ if (switch_op == 0) {
50
+ C[globalRow * N + globalCol] = sigmoid_grad(A[globalRow * N + globalCol], B[b_m_index * N2 + b_n_index]);
51
+ } else {
52
+ C[globalRow * N + globalCol] = sigmoid_grad(B[b_m_index * N2 + b_n_index], A[globalRow * N + globalCol]);
53
+ }
54
+ }