llama_cpp 0.6.0 → 0.7.1

@@ -13,8 +13,8 @@ typedef struct {
 
 #define QK4_1 32
 typedef struct {
-    half d;          // delta
-    half m;          // min
+    half d;                 // delta
+    half m;                 // min
     uint8_t qs[QK4_1 / 2]; // nibbles / quants
 } block_q4_1;
 
@@ -132,6 +132,13 @@ kernel void kernel_relu(
     dst[tpig] = max(0.0f, src0[tpig]);
 }
 
+kernel void kernel_sqr(
+        device const float * src0,
+        device       float * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = src0[tpig] * src0[tpig];
+}
+
 constant float GELU_COEF_A    = 0.044715f;
 constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
 
@@ -338,10 +345,11 @@ kernel void kernel_rms_norm(
         uint sgitg[[simdgroup_index_in_threadgroup]],
         uint tiisg[[thread_index_in_simdgroup]],
         uint ntg[[threads_per_threadgroup]]) {
-    device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01);
-    device const float * x_scalar = (device const float *) x;
-    float4 sumf=0;
-    float all_sum=0;
+    device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01);
+    device const float * x_scalar = (device const float *) x;
+
+    float4 sumf = 0;
+    float all_sum = 0;
 
     // parallel sum
     for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
@@ -354,6 +362,7 @@ kernel void kernel_rms_norm(
     }
 
     threadgroup_barrier(mem_flags::mem_threadgroup);
+
     // broadcast, simd group number is ntg / 32
     for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
         if (tpitg < i) {
@@ -361,7 +370,9 @@ kernel void kernel_rms_norm(
         }
     }
     if (tpitg == 0) {
-        for (int i = 4 * (ne00 / 4); i < ne00; i++) {sum[0] += x_scalar[i];}
+        for (int i = 4 * (ne00 / 4); i < ne00; i++) {
+            sum[0] += x_scalar[i];
+        }
         sum[0] /= ne00;
     }
 
@@ -376,7 +387,9 @@ kernel void kernel_rms_norm(
         y[i00] = x[i00] * scale;
     }
     if (tpitg == 0) {
-        for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) {y_scalar[i00] = x_scalar[i00] * scale;}
+        for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) {
+            y_scalar[i00] = x_scalar[i00] * scale;
+        }
     }
 }
 
@@ -416,8 +429,8 @@ inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thre
 }
 
 // putting them in the kernel cause a significant performance penalty
-#define N_DST 4 // each SIMD group works on 4 rows
-#define N_SIMDGROUP 2 // number of SIMD groups in a thread group
+#define N_DST 4        // each SIMD group works on 4 rows
+#define N_SIMDGROUP 2  // number of SIMD groups in a thread group
 #define N_SIMDWIDTH 32 // assuming SIMD group size is 32
 //Note: This is a template, but strictly speaking it only applies to
 //      quantizations where the block size is 32. It also does not
@@ -428,18 +441,23 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device
                     int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne10, int64_t ne12, int64_t ne0, int64_t ne1, uint gqa,
                     uint3 tgpig, uint tiisg, uint sgitg) {
     const int nb = ne00/QK4_0;
+
     const int r0 = tgpig.x;
     const int r1 = tgpig.y;
     const int im = tgpig.z;
+
     const int first_row = (r0 * nsg + sgitg) * nr;
+
     const uint offset0 = first_row * nb + im/gqa*(nb*ne0);
+
     device const block_q_type * x = (device const block_q_type *) src0 + offset0;
     device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
-    float yl[16]; // src1 vector cache
-    float sumf[nr]={0.f};
 
-    const int ix = tiisg/2;
-    const int il = 8*(tiisg%2);
+    float yl[16]; // src1 vector cache
+    float sumf[nr] = {0.f};
+
+    const int ix = (tiisg/2);
+    const int il = (tiisg%2)*8;
 
     device const float * yb = y + ix * QK4_0 + il;
 
@@ -450,6 +468,7 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device
             sumy += yb[i] + yb[i+1];
             yl[i+0] = yb[i+ 0];
             yl[i+1] = yb[i+ 1]/256.f;
+
             sumy += yb[i+16] + yb[i+17];
             yl[i+8] = yb[i+16]/16.f;
             yl[i+9] = yb[i+17]/4096.f;
@@ -465,12 +484,12 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device
     for (int row = 0; row < nr; ++row) {
         const float tot = simd_sum(sumf[row]);
         if (tiisg == 0 && first_row + row < ne01) {
-            dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot;
+            dst[im*ne0*ne1 + r1*ne0 + first_row + row] = tot;
         }
     }
 }
 
-kernel void kernel_mul_mat_q4_0_f32(
+kernel void kernel_mul_mv_q4_0_f32(
         device const void * src0,
         device const float * src1,
         device float * dst,
@@ -483,12 +502,12 @@ kernel void kernel_mul_mat_q4_0_f32(
         constant int64_t & ne1[[buffer(16)]],
         constant uint & gqa[[buffer(17)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
-        uint tiisg[[thread_index_in_simdgroup]],
-        uint sgitg[[simdgroup_index_in_threadgroup]]) {
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
     mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
 }
 
-kernel void kernel_mul_mat_q4_1_f32(
+kernel void kernel_mul_mv_q4_1_f32(
         device const void * src0,
         device const float * src1,
         device float * dst,
@@ -508,7 +527,7 @@ kernel void kernel_mul_mat_q4_1_f32(
 
 #define NB_Q8_0 8
 
-kernel void kernel_mul_mat_q8_0_f32(
+kernel void kernel_mul_mv_q8_0_f32(
         device const void * src0,
         device const float * src1,
         device float * dst,
@@ -572,7 +591,7 @@ kernel void kernel_mul_mat_q8_0_f32(
 
 #define N_F32_F32 4
 
-kernel void kernel_mul_mat_f32_f32(
+kernel void kernel_mul_mv_f32_f32(
         device const char * src0,
         device const char * src1,
         device float * dst,
@@ -643,7 +662,7 @@ kernel void kernel_mul_mat_f32_f32(
     }
 }
 
-kernel void kernel_mul_mat_f16_f32_1row(
+kernel void kernel_mul_mv_f16_f32_1row(
         device const char * src0,
         device const char * src1,
         device float * dst,
@@ -662,7 +681,7 @@ kernel void kernel_mul_mat_f16_f32_1row(
         constant int64_t & ne0,
         constant int64_t & ne1,
         uint3 tgpig[[threadgroup_position_in_grid]],
-        uint tiisg[[thread_index_in_simdgroup]]) {
+        uint  tiisg[[thread_index_in_simdgroup]]) {
 
     const int64_t r0 = tgpig.x;
     const int64_t r1 = tgpig.y;
@@ -697,7 +716,7 @@ kernel void kernel_mul_mat_f16_f32_1row(
 
 #define N_F16_F32 4
 
-kernel void kernel_mul_mat_f16_f32(
+kernel void kernel_mul_mv_f16_f32(
         device const char * src0,
         device const char * src1,
         device float * dst,
@@ -769,7 +788,7 @@ kernel void kernel_mul_mat_f16_f32(
 }
 
 // Assumes row size (ne00) is a multiple of 4
-kernel void kernel_mul_mat_f16_f32_l4(
+kernel void kernel_mul_mv_f16_f32_l4(
         device const char * src0,
         device const char * src1,
         device float * dst,
@@ -830,7 +849,9 @@ kernel void kernel_alibi_f32(
         constant uint64_t & nb1,
         constant uint64_t & nb2,
         constant uint64_t & nb3,
-        constant float & m0,
+        constant    float & m0,
+        constant    float & m1,
+        constant      int & n_heads_log2_floor,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint3 tpitg[[thread_position_in_threadgroup]],
         uint3 ntg[[threads_per_threadgroup]]) {
@@ -846,7 +867,12 @@ kernel void kernel_alibi_f32(
     const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
 
     device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
-    float m_k = pow(m0, i2 + 1);
+    float m_k;
+    if (i2 < n_heads_log2_floor) {
+        m_k = pow(m0, i2 + 1);
+    } else {
+        m_k = pow(m1, 2 * (i2 - n_heads_log2_floor) + 1);
+    }
     for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
         device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
         dst_data[i00] = src[0] + m_k * (i00 - ne00 + 1);
@@ -1091,6 +1117,62 @@ kernel void kernel_cpy_f32_f32(
     }
 }
 
+kernel void kernel_concat(
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        constant  int64_t & ne00,
+        constant  int64_t & ne01,
+        constant  int64_t & ne02,
+        constant  int64_t & ne03,
+        constant uint64_t & nb00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant uint64_t & nb03,
+        constant  int64_t & ne10,
+        constant  int64_t & ne11,
+        constant  int64_t & ne12,
+        constant  int64_t & ne13,
+        constant uint64_t & nb10,
+        constant uint64_t & nb11,
+        constant uint64_t & nb12,
+        constant uint64_t & nb13,
+        constant  int64_t & ne0,
+        constant  int64_t & ne1,
+        constant  int64_t & ne2,
+        constant  int64_t & ne3,
+        constant uint64_t & nb0,
+        constant uint64_t & nb1,
+        constant uint64_t & nb2,
+        constant uint64_t & nb3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3 ntg[[threads_per_threadgroup]]) {
+
+    const int64_t i03 = tgpig.z;
+    const int64_t i02 = tgpig.y;
+    const int64_t i01 = tgpig.x;
+
+    const int64_t i13 = i03 % ne13;
+    const int64_t i12 = i02 % ne12;
+    const int64_t i11 = i01 % ne11;
+
+    device const char * src0_ptr = src0 + i03 * nb03 + i02 * nb02 + i01 * nb01 + tpitg.x*nb00;
+    device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11 + tpitg.x*nb10;
+    device       char * dst_ptr  = dst + i03*nb3 + i02*nb2 + i01*nb1 + tpitg.x*nb0;
+
+    for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+        if (i02 < ne02) {
+            ((device float *)dst_ptr)[0] = ((device float *)src0_ptr)[0];
+            src0_ptr += ntg.x*nb00;
+        } else {
+            ((device float *)dst_ptr)[0] = ((device float *)src1_ptr)[0];
+            src1_ptr += ntg.x*nb10;
+        }
+        dst_ptr += ntg.x*nb0;
+    }
+}
+
 //============================================ k-quants ======================================================
 
 #ifndef QK_K
@@ -1183,7 +1265,7 @@ static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
 
 //====================================== dot products =========================
 
-kernel void kernel_mul_mat_q2_K_f32(
+kernel void kernel_mul_mv_q2_K_f32(
         device const void * src0,
         device const float * src1,
         device float * dst,
@@ -1327,7 +1409,7 @@ kernel void kernel_mul_mat_q2_K_f32(
 }
 
 #if QK_K == 256
-kernel void kernel_mul_mat_q3_K_f32(
+kernel void kernel_mul_mv_q3_K_f32(
         device const void * src0,
         device const float * src1,
         device float * dst,
@@ -1479,7 +1561,7 @@ kernel void kernel_mul_mat_q3_K_f32(
     }
 }
 #else
-kernel void kernel_mul_mat_q3_K_f32(
+kernel void kernel_mul_mv_q3_K_f32(
         device const void * src0,
         device const float * src1,
         device float * dst,
@@ -1550,7 +1632,7 @@ kernel void kernel_mul_mat_q3_K_f32(
 #endif
 
 #if QK_K == 256
-kernel void kernel_mul_mat_q4_K_f32(
+kernel void kernel_mul_mv_q4_K_f32(
         device const void * src0,
         device const float * src1,
         device float * dst,
@@ -1656,7 +1738,7 @@ kernel void kernel_mul_mat_q4_K_f32(
     }
 }
 #else
-kernel void kernel_mul_mat_q4_K_f32(
+kernel void kernel_mul_mv_q4_K_f32(
         device const void * src0,
         device const float * src1,
         device float * dst,
@@ -1745,7 +1827,7 @@ kernel void kernel_mul_mat_q4_K_f32(
 }
 #endif
 
-kernel void kernel_mul_mat_q5_K_f32(
+kernel void kernel_mul_mv_q5_K_f32(
         device const void * src0,
         device const float * src1,
         device float * dst,
@@ -1918,7 +2000,7 @@ kernel void kernel_mul_mat_q5_K_f32(
 
 }
 
-kernel void kernel_mul_mat_q6_K_f32(
+kernel void kernel_mul_mv_q6_K_f32(
         device const void * src0,
         device const float * src1,
         device float * dst,
@@ -2256,7 +2338,7 @@ kernel void kernel_get_rows(
 }
 
 #define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A
-#define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix A
+#define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B
 #define BLOCK_SIZE_K 32
 #define THREAD_MAT_M 4 // each thread take 4 simdgroup matrices from matrix A
 #define THREAD_MAT_N 2 // each thread take 2 simdgroup matrices from matrix B
@@ -2293,9 +2375,11 @@ kernel void kernel_mul_mm(device const uchar * src0,
     const uint r0 = tgpig.y;
     const uint r1 = tgpig.x;
     const uint im = tgpig.z;
+
     // if this block is of 64x32 shape or smaller
     short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M;
     short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N;
+
     // a thread shouldn't load data outside of the matrix
     short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
     short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
@@ -2319,26 +2403,30 @@ kernel void kernel_mul_mm(device const uchar * src0,
         + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
 
     for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
-        //load data and store to threadgroup memory
+        // load data and store to threadgroup memory
         half4x4 temp_a;
         dequantize_func(x, il, temp_a);
         threadgroup_barrier(mem_flags::mem_threadgroup);
+
         #pragma unroll(16)
         for (int i = 0; i < 16; i++) {
             *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \
-            + 16 * (tiitg % THREAD_PER_ROW) + 8 * (i / 8)) \
-            + (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4];
+            + (tiitg % THREAD_PER_ROW) * 16 + (i / 8) * 8) \
+            + (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4];
         }
-        *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) \
-        = *((device float2x4 *)y);
+
+        *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y);
+
         il = (il + 2 < nl) ? il + 2 : il % 2;
         x = (il < 2) ? x + (2+nl-1)/nl : x;
         y += BLOCK_SIZE_K;
 
         threadgroup_barrier(mem_flags::mem_threadgroup);
-        //load matrices from threadgroup memory and conduct outer products
+
+        // load matrices from threadgroup memory and conduct outer products
         threadgroup half * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
         threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));
+
         #pragma unroll(4)
         for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
             #pragma unroll(4)
@@ -2353,6 +2441,7 @@ kernel void kernel_mul_mm(device const uchar * src0,
 
             lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
             lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE;
+
             #pragma unroll(8)
             for (int i = 0; i < 8; i++){
                 simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]);
@@ -2361,25 +2450,26 @@ kernel void kernel_mul_mm(device const uchar * src0,
     }
 
     if ((r0 + 1) * BLOCK_SIZE_M <= ne0 && (r1 + 1) * BLOCK_SIZE_N <= ne1) {
-        device float *C = dst + BLOCK_SIZE_M * r0 + 32 * (sgitg&1) \
-                        + (BLOCK_SIZE_N * r1 + 16 * (sgitg>>1)) * ne0 + im*ne1*ne0;
+        device float * C = dst + (BLOCK_SIZE_M * r0 + 32 * (sgitg & 1)) \
+                         + (BLOCK_SIZE_N * r1 + 16 * (sgitg >> 1)) * ne0 + im*ne1*ne0;
         for (int i = 0; i < 8; i++) {
             simdgroup_store(c_res[i], C + 8 * (i%4) + 8 * ne0 * (i/4), ne0);
         }
     } else {
         // block is smaller than 64x32, we should avoid writing data outside of the matrix
         threadgroup_barrier(mem_flags::mem_threadgroup);
-        threadgroup float *temp_str = ((threadgroup float *)shared_memory) \
+        threadgroup float * temp_str = ((threadgroup float *)shared_memory) \
                                       + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
         for (int i = 0; i < 8; i++) {
             simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
         }
 
         threadgroup_barrier(mem_flags::mem_threadgroup);
-        device float *C = dst + BLOCK_SIZE_M * r0 + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0;
-        if (sgitg==0) {
+
+        device float * C = dst + (BLOCK_SIZE_M * r0) + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0;
+        if (sgitg == 0) {
             for (int i = 0; i < n_rows; i++) {
-                for (int j = tiitg; j< n_cols; j += BLOCK_SIZE_N) {
+                for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
                     *(C + i + j * ne0) = *(temp_str + i + j * BLOCK_SIZE_M);
                 }
             }
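
The new m1 and n_heads_log2_floor arguments to kernel_alibi_f32 above implement the usual ALiBi slope schedule for head counts that are not a power of two: the first n_heads_log2_floor heads use slopes m0^(i2 + 1), and the remaining heads use m1^(2*(i2 - n_heads_log2_floor) + 1). A minimal host-side sketch of how these parameters are conventionally derived (an illustration following the standard ALiBi recipe; the function name is hypothetical and this is not code taken from the release):

#include <math.h>

// Illustrative sketch: derive the slope parameters passed to kernel_alibi_f32.
// n_heads_log2_floor is the largest power of two <= n_head; m0 is the slope base
// for the first n_heads_log2_floor heads, m1 for the rest (assumed recipe).
static void alibi_slope_params(int n_head, float * m0, float * m1, int * n_heads_log2_floor) {
    *n_heads_log2_floor = 1 << (int) floorf(log2f((float) n_head));
    *m0 = powf(2.0f, -8.0f / *n_heads_log2_floor);
    *m1 = powf(2.0f, -4.0f / *n_heads_log2_floor);
}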