llama_cpp 0.6.0 → 0.7.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -13,8 +13,8 @@ typedef struct {
13
13
 
14
14
  #define QK4_1 32
15
15
  typedef struct {
16
- half d; // delta
17
- half m; // min
16
+ half d; // delta
17
+ half m; // min
18
18
  uint8_t qs[QK4_1 / 2]; // nibbles / quants
19
19
  } block_q4_1;
20
20
 
@@ -132,6 +132,13 @@ kernel void kernel_relu(
132
132
  dst[tpig] = max(0.0f, src0[tpig]);
133
133
  }
134
134
 
135
// Element-wise square of an f32 buffer: dst[i] = src0[i] * src0[i].
// Each thread handles the single element at its grid position; there is no
// bounds check, so the dispatch is expected to launch exactly one thread per
// element — NOTE(review): confirm against the host-side encoder.
kernel void kernel_sqr(
        device const float * src0,
        device       float * dst,
        uint tpig[[thread_position_in_grid]]) {
    const float x = src0[tpig];
    dst[tpig] = x * x;
}
141
+
135
142
  constant float GELU_COEF_A = 0.044715f;
136
143
  constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
137
144
 
@@ -338,10 +345,11 @@ kernel void kernel_rms_norm(
338
345
  uint sgitg[[simdgroup_index_in_threadgroup]],
339
346
  uint tiisg[[thread_index_in_simdgroup]],
340
347
  uint ntg[[threads_per_threadgroup]]) {
341
- device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01);
342
- device const float * x_scalar = (device const float *) x;
343
- float4 sumf=0;
344
- float all_sum=0;
348
+ device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01);
349
+ device const float * x_scalar = (device const float *) x;
350
+
351
+ float4 sumf = 0;
352
+ float all_sum = 0;
345
353
 
346
354
  // parallel sum
347
355
  for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
@@ -354,6 +362,7 @@ kernel void kernel_rms_norm(
354
362
  }
355
363
 
356
364
  threadgroup_barrier(mem_flags::mem_threadgroup);
365
+
357
366
  // broadcast, simd group number is ntg / 32
358
367
  for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
359
368
  if (tpitg < i) {
@@ -361,7 +370,9 @@ kernel void kernel_rms_norm(
361
370
  }
362
371
  }
363
372
  if (tpitg == 0) {
364
- for (int i = 4 * (ne00 / 4); i < ne00; i++) {sum[0] += x_scalar[i];}
373
+ for (int i = 4 * (ne00 / 4); i < ne00; i++) {
374
+ sum[0] += x_scalar[i];
375
+ }
365
376
  sum[0] /= ne00;
366
377
  }
367
378
 
@@ -376,7 +387,9 @@ kernel void kernel_rms_norm(
376
387
  y[i00] = x[i00] * scale;
377
388
  }
378
389
  if (tpitg == 0) {
379
- for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) {y_scalar[i00] = x_scalar[i00] * scale;}
390
+ for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) {
391
+ y_scalar[i00] = x_scalar[i00] * scale;
392
+ }
380
393
  }
381
394
  }
382
395
 
@@ -416,8 +429,8 @@ inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thre
416
429
  }
417
430
 
418
431
  // putting them in the kernel cause a significant performance penalty
419
- #define N_DST 4 // each SIMD group works on 4 rows
420
- #define N_SIMDGROUP 2 // number of SIMD groups in a thread group
432
+ #define N_DST 4 // each SIMD group works on 4 rows
433
+ #define N_SIMDGROUP 2 // number of SIMD groups in a thread group
421
434
  #define N_SIMDWIDTH 32 // assuming SIMD group size is 32
422
435
  //Note: This is a template, but strictly speaking it only applies to
423
436
  // quantizations where the block size is 32. It also does not
@@ -428,18 +441,23 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device
428
441
  int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne10, int64_t ne12, int64_t ne0, int64_t ne1, uint gqa,
429
442
  uint3 tgpig, uint tiisg, uint sgitg) {
430
443
  const int nb = ne00/QK4_0;
444
+
431
445
  const int r0 = tgpig.x;
432
446
  const int r1 = tgpig.y;
433
447
  const int im = tgpig.z;
448
+
434
449
  const int first_row = (r0 * nsg + sgitg) * nr;
450
+
435
451
  const uint offset0 = first_row * nb + im/gqa*(nb*ne0);
452
+
436
453
  device const block_q_type * x = (device const block_q_type *) src0 + offset0;
437
454
  device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
438
- float yl[16]; // src1 vector cache
439
- float sumf[nr]={0.f};
440
455
 
441
- const int ix = tiisg/2;
442
- const int il = 8*(tiisg%2);
456
+ float yl[16]; // src1 vector cache
457
+ float sumf[nr] = {0.f};
458
+
459
+ const int ix = (tiisg/2);
460
+ const int il = (tiisg%2)*8;
443
461
 
444
462
  device const float * yb = y + ix * QK4_0 + il;
445
463
 
@@ -450,6 +468,7 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device
450
468
  sumy += yb[i] + yb[i+1];
451
469
  yl[i+0] = yb[i+ 0];
452
470
  yl[i+1] = yb[i+ 1]/256.f;
471
+
453
472
  sumy += yb[i+16] + yb[i+17];
454
473
  yl[i+8] = yb[i+16]/16.f;
455
474
  yl[i+9] = yb[i+17]/4096.f;
@@ -465,12 +484,12 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device
465
484
  for (int row = 0; row < nr; ++row) {
466
485
  const float tot = simd_sum(sumf[row]);
467
486
  if (tiisg == 0 && first_row + row < ne01) {
468
- dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot;
487
+ dst[im*ne0*ne1 + r1*ne0 + first_row + row] = tot;
469
488
  }
470
489
  }
471
490
  }
472
491
 
473
- kernel void kernel_mul_mat_q4_0_f32(
492
+ kernel void kernel_mul_mv_q4_0_f32(
474
493
  device const void * src0,
475
494
  device const float * src1,
476
495
  device float * dst,
@@ -483,12 +502,12 @@ kernel void kernel_mul_mat_q4_0_f32(
483
502
  constant int64_t & ne1[[buffer(16)]],
484
503
  constant uint & gqa[[buffer(17)]],
485
504
  uint3 tgpig[[threadgroup_position_in_grid]],
486
- uint tiisg[[thread_index_in_simdgroup]],
487
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
505
+ uint tiisg[[thread_index_in_simdgroup]],
506
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
488
507
  mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
489
508
  }
490
509
 
491
- kernel void kernel_mul_mat_q4_1_f32(
510
+ kernel void kernel_mul_mv_q4_1_f32(
492
511
  device const void * src0,
493
512
  device const float * src1,
494
513
  device float * dst,
@@ -508,7 +527,7 @@ kernel void kernel_mul_mat_q4_1_f32(
508
527
 
509
528
  #define NB_Q8_0 8
510
529
 
511
- kernel void kernel_mul_mat_q8_0_f32(
530
+ kernel void kernel_mul_mv_q8_0_f32(
512
531
  device const void * src0,
513
532
  device const float * src1,
514
533
  device float * dst,
@@ -572,7 +591,7 @@ kernel void kernel_mul_mat_q8_0_f32(
572
591
 
573
592
  #define N_F32_F32 4
574
593
 
575
- kernel void kernel_mul_mat_f32_f32(
594
+ kernel void kernel_mul_mv_f32_f32(
576
595
  device const char * src0,
577
596
  device const char * src1,
578
597
  device float * dst,
@@ -643,7 +662,7 @@ kernel void kernel_mul_mat_f32_f32(
643
662
  }
644
663
  }
645
664
 
646
- kernel void kernel_mul_mat_f16_f32_1row(
665
+ kernel void kernel_mul_mv_f16_f32_1row(
647
666
  device const char * src0,
648
667
  device const char * src1,
649
668
  device float * dst,
@@ -662,7 +681,7 @@ kernel void kernel_mul_mat_f16_f32_1row(
662
681
  constant int64_t & ne0,
663
682
  constant int64_t & ne1,
664
683
  uint3 tgpig[[threadgroup_position_in_grid]],
665
- uint tiisg[[thread_index_in_simdgroup]]) {
684
+ uint tiisg[[thread_index_in_simdgroup]]) {
666
685
 
667
686
  const int64_t r0 = tgpig.x;
668
687
  const int64_t r1 = tgpig.y;
@@ -697,7 +716,7 @@ kernel void kernel_mul_mat_f16_f32_1row(
697
716
 
698
717
  #define N_F16_F32 4
699
718
 
700
- kernel void kernel_mul_mat_f16_f32(
719
+ kernel void kernel_mul_mv_f16_f32(
701
720
  device const char * src0,
702
721
  device const char * src1,
703
722
  device float * dst,
@@ -769,7 +788,7 @@ kernel void kernel_mul_mat_f16_f32(
769
788
  }
770
789
 
771
790
  // Assumes row size (ne00) is a multiple of 4
772
- kernel void kernel_mul_mat_f16_f32_l4(
791
+ kernel void kernel_mul_mv_f16_f32_l4(
773
792
  device const char * src0,
774
793
  device const char * src1,
775
794
  device float * dst,
@@ -830,7 +849,9 @@ kernel void kernel_alibi_f32(
830
849
  constant uint64_t & nb1,
831
850
  constant uint64_t & nb2,
832
851
  constant uint64_t & nb3,
833
- constant float & m0,
852
+ constant float & m0,
853
+ constant float & m1,
854
+ constant int & n_heads_log2_floor,
834
855
  uint3 tgpig[[threadgroup_position_in_grid]],
835
856
  uint3 tpitg[[thread_position_in_threadgroup]],
836
857
  uint3 ntg[[threads_per_threadgroup]]) {
@@ -846,7 +867,12 @@ kernel void kernel_alibi_f32(
846
867
  const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
847
868
 
848
869
  device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
849
- float m_k = pow(m0, i2 + 1);
870
+ float m_k;
871
+ if (i2 < n_heads_log2_floor) {
872
+ m_k = pow(m0, i2 + 1);
873
+ } else {
874
+ m_k = pow(m1, 2 * (i2 - n_heads_log2_floor) + 1);
875
+ }
850
876
  for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
851
877
  device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
852
878
  dst_data[i00] = src[0] + m_k * (i00 - ne00 + 1);
@@ -1091,6 +1117,62 @@ kernel void kernel_cpy_f32_f32(
1091
1117
  }
1092
1118
  }
1093
1119
 
1120
// Concatenate two f32 tensors along dimension 2 into dst.
//
// Layout: each threadgroup handles one dst row (i01, i02, i03) taken from
// tgpig.{x,y,z}; the threads of the group stride over the row's ne0 elements.
// Planes i02 < ne02 are copied from src0; planes i02 >= ne02 are copied from
// src1 plane (i02 - ne02).
// NOTE(review): assumes the host sets ne2 == ne02 + ne12 and dispatches a
// (ne1, ne2, ne3) threadgroup grid — confirm against the caller.
// src1 rows are selected with i03 % ne13 and i01 % ne11 (kept from the original
// kernel), which lets src1 broadcast along dims 1 and 3.
// ne*/nb* are element counts and byte strides of the respective tensors.
kernel void kernel_concat(
    device const char * src0,
    device const char * src1,
    device       char * dst,
    constant   int64_t & ne00,
    constant   int64_t & ne01,
    constant   int64_t & ne02,
    constant   int64_t & ne03,
    constant  uint64_t & nb00,
    constant  uint64_t & nb01,
    constant  uint64_t & nb02,
    constant  uint64_t & nb03,
    constant   int64_t & ne10,
    constant   int64_t & ne11,
    constant   int64_t & ne12,
    constant   int64_t & ne13,
    constant  uint64_t & nb10,
    constant  uint64_t & nb11,
    constant  uint64_t & nb12,
    constant  uint64_t & nb13,
    constant   int64_t & ne0,
    constant   int64_t & ne1,
    constant   int64_t & ne2,
    constant   int64_t & ne3,
    constant  uint64_t & nb0,
    constant  uint64_t & nb1,
    constant  uint64_t & nb2,
    constant  uint64_t & nb3,
    uint3 tgpig[[threadgroup_position_in_grid]],
    uint3 tpitg[[thread_position_in_threadgroup]],
    uint3   ntg[[threads_per_threadgroup]]) {

    const int64_t i03 = tgpig.z;
    const int64_t i02 = tgpig.y;
    const int64_t i01 = tgpig.x;

    // i02 is the same for every thread in the threadgroup, so pick the source
    // tensor once instead of re-testing it per element. This also avoids
    // forming a pointer into the tensor that is not read.
    device const char * src_row;
    uint64_t            src_nb0;
    if (i02 < ne02) {
        src_row = src0 + i03*nb03 + i02*nb02 + i01*nb01;
        src_nb0 = nb00;
    } else {
        const int64_t i13 = i03 % ne13;
        // Offset into src1 along the concatenated dim. The original used
        // i02 % ne12, which is only correct when ne02 == ne12.
        const int64_t i12 = i02 - ne02;
        const int64_t i11 = i01 % ne11;
        src_row = src1 + i13*nb13 + i12*nb12 + i11*nb11;
        src_nb0 = nb10;
    }

    device char * dst_row = dst + i03*nb3 + i02*nb2 + i01*nb1;

    for (int64_t i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
        *((device float *) (dst_row + i0*nb0)) = *((device const float *) (src_row + i0*src_nb0));
    }
}
1175
+
1094
1176
  //============================================ k-quants ======================================================
1095
1177
 
1096
1178
  #ifndef QK_K
@@ -1183,7 +1265,7 @@ static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
1183
1265
 
1184
1266
  //====================================== dot products =========================
1185
1267
 
1186
- kernel void kernel_mul_mat_q2_K_f32(
1268
+ kernel void kernel_mul_mv_q2_K_f32(
1187
1269
  device const void * src0,
1188
1270
  device const float * src1,
1189
1271
  device float * dst,
@@ -1327,7 +1409,7 @@ kernel void kernel_mul_mat_q2_K_f32(
1327
1409
  }
1328
1410
 
1329
1411
  #if QK_K == 256
1330
- kernel void kernel_mul_mat_q3_K_f32(
1412
+ kernel void kernel_mul_mv_q3_K_f32(
1331
1413
  device const void * src0,
1332
1414
  device const float * src1,
1333
1415
  device float * dst,
@@ -1479,7 +1561,7 @@ kernel void kernel_mul_mat_q3_K_f32(
1479
1561
  }
1480
1562
  }
1481
1563
  #else
1482
- kernel void kernel_mul_mat_q3_K_f32(
1564
+ kernel void kernel_mul_mv_q3_K_f32(
1483
1565
  device const void * src0,
1484
1566
  device const float * src1,
1485
1567
  device float * dst,
@@ -1550,7 +1632,7 @@ kernel void kernel_mul_mat_q3_K_f32(
1550
1632
  #endif
1551
1633
 
1552
1634
  #if QK_K == 256
1553
- kernel void kernel_mul_mat_q4_K_f32(
1635
+ kernel void kernel_mul_mv_q4_K_f32(
1554
1636
  device const void * src0,
1555
1637
  device const float * src1,
1556
1638
  device float * dst,
@@ -1656,7 +1738,7 @@ kernel void kernel_mul_mat_q4_K_f32(
1656
1738
  }
1657
1739
  }
1658
1740
  #else
1659
- kernel void kernel_mul_mat_q4_K_f32(
1741
+ kernel void kernel_mul_mv_q4_K_f32(
1660
1742
  device const void * src0,
1661
1743
  device const float * src1,
1662
1744
  device float * dst,
@@ -1745,7 +1827,7 @@ kernel void kernel_mul_mat_q4_K_f32(
1745
1827
  }
1746
1828
  #endif
1747
1829
 
1748
- kernel void kernel_mul_mat_q5_K_f32(
1830
+ kernel void kernel_mul_mv_q5_K_f32(
1749
1831
  device const void * src0,
1750
1832
  device const float * src1,
1751
1833
  device float * dst,
@@ -1918,7 +2000,7 @@ kernel void kernel_mul_mat_q5_K_f32(
1918
2000
 
1919
2001
  }
1920
2002
 
1921
- kernel void kernel_mul_mat_q6_K_f32(
2003
+ kernel void kernel_mul_mv_q6_K_f32(
1922
2004
  device const void * src0,
1923
2005
  device const float * src1,
1924
2006
  device float * dst,
@@ -2256,7 +2338,7 @@ kernel void kernel_get_rows(
2256
2338
  }
2257
2339
 
2258
2340
  #define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A
2259
- #define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix A
2341
+ #define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B
2260
2342
  #define BLOCK_SIZE_K 32
2261
2343
  #define THREAD_MAT_M 4 // each thread take 4 simdgroup matrices from matrix A
2262
2344
  #define THREAD_MAT_N 2 // each thread take 2 simdgroup matrices from matrix B
@@ -2293,9 +2375,11 @@ kernel void kernel_mul_mm(device const uchar * src0,
2293
2375
  const uint r0 = tgpig.y;
2294
2376
  const uint r1 = tgpig.x;
2295
2377
  const uint im = tgpig.z;
2378
+
2296
2379
  // if this block is of 64x32 shape or smaller
2297
2380
  short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M;
2298
2381
  short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N;
2382
+
2299
2383
  // a thread shouldn't load data outside of the matrix
2300
2384
  short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
2301
2385
  short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
@@ -2319,26 +2403,30 @@ kernel void kernel_mul_mm(device const uchar * src0,
2319
2403
  + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
2320
2404
 
2321
2405
  for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
2322
- //load data and store to threadgroup memory
2406
+ // load data and store to threadgroup memory
2323
2407
  half4x4 temp_a;
2324
2408
  dequantize_func(x, il, temp_a);
2325
2409
  threadgroup_barrier(mem_flags::mem_threadgroup);
2410
+
2326
2411
  #pragma unroll(16)
2327
2412
  for (int i = 0; i < 16; i++) {
2328
2413
  *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \
2329
- + 16 * (tiitg % THREAD_PER_ROW) + 8 * (i / 8)) \
2330
- + (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4];
2414
+ + (tiitg % THREAD_PER_ROW) * 16 + (i / 8) * 8) \
2415
+ + (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4];
2331
2416
  }
2332
- *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) \
2333
- = *((device float2x4 *)y);
2417
+
2418
+ *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y);
2419
+
2334
2420
  il = (il + 2 < nl) ? il + 2 : il % 2;
2335
2421
  x = (il < 2) ? x + (2+nl-1)/nl : x;
2336
2422
  y += BLOCK_SIZE_K;
2337
2423
 
2338
2424
  threadgroup_barrier(mem_flags::mem_threadgroup);
2339
- //load matrices from threadgroup memory and conduct outer products
2425
+
2426
+ // load matrices from threadgroup memory and conduct outer products
2340
2427
  threadgroup half * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
2341
2428
  threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));
2429
+
2342
2430
  #pragma unroll(4)
2343
2431
  for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
2344
2432
  #pragma unroll(4)
@@ -2353,6 +2441,7 @@ kernel void kernel_mul_mm(device const uchar * src0,
2353
2441
 
2354
2442
  lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
2355
2443
  lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE;
2444
+
2356
2445
  #pragma unroll(8)
2357
2446
  for (int i = 0; i < 8; i++){
2358
2447
  simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]);
@@ -2361,25 +2450,26 @@ kernel void kernel_mul_mm(device const uchar * src0,
2361
2450
  }
2362
2451
 
2363
2452
  if ((r0 + 1) * BLOCK_SIZE_M <= ne0 && (r1 + 1) * BLOCK_SIZE_N <= ne1) {
2364
- device float *C = dst + BLOCK_SIZE_M * r0 + 32 * (sgitg&1) \
2365
- + (BLOCK_SIZE_N * r1 + 16 * (sgitg>>1)) * ne0 + im*ne1*ne0;
2453
+ device float * C = dst + (BLOCK_SIZE_M * r0 + 32 * (sgitg & 1)) \
2454
+ + (BLOCK_SIZE_N * r1 + 16 * (sgitg >> 1)) * ne0 + im*ne1*ne0;
2366
2455
  for (int i = 0; i < 8; i++) {
2367
2456
  simdgroup_store(c_res[i], C + 8 * (i%4) + 8 * ne0 * (i/4), ne0);
2368
2457
  }
2369
2458
  } else {
2370
2459
  // block is smaller than 64x32, we should avoid writing data outside of the matrix
2371
2460
  threadgroup_barrier(mem_flags::mem_threadgroup);
2372
- threadgroup float *temp_str = ((threadgroup float *)shared_memory) \
2461
+ threadgroup float * temp_str = ((threadgroup float *)shared_memory) \
2373
2462
  + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
2374
2463
  for (int i = 0; i < 8; i++) {
2375
2464
  simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
2376
2465
  }
2377
2466
 
2378
2467
  threadgroup_barrier(mem_flags::mem_threadgroup);
2379
- device float *C = dst + BLOCK_SIZE_M * r0 + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0;
2380
- if (sgitg==0) {
2468
+
2469
+ device float * C = dst + (BLOCK_SIZE_M * r0) + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0;
2470
+ if (sgitg == 0) {
2381
2471
  for (int i = 0; i < n_rows; i++) {
2382
- for (int j = tiitg; j< n_cols; j += BLOCK_SIZE_N) {
2472
+ for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
2383
2473
  *(C + i + j * ne0) = *(temp_str + i + j * BLOCK_SIZE_M);
2384
2474
  }
2385
2475
  }