whisper.rn 0.4.0-rc.6 → 0.4.0-rc.8

package/cpp/ggml.c CHANGED
@@ -132,7 +132,7 @@ void wsp_ggml_print_backtrace(void) {
             "-ex", "bt -frame-info source-and-location",
             "-ex", "detach",
             "-ex", "quit",
-            NULL);
+            (char *) NULL);
     } else {
         waitpid(pid, NULL, 0);
     }
@@ -573,6 +573,28 @@ static const wsp_ggml_type_traits_t type_traits[WSP_GGML_TYPE_COUNT] = {
         .vec_dot = wsp_ggml_vec_dot_q6_K_q8_K,
         .vec_dot_type = WSP_GGML_TYPE_Q8_K,
     },
+    [WSP_GGML_TYPE_IQ2_XXS] = {
+        .type_name = "iq2_xxs",
+        .blck_size = QK_K,
+        .type_size = sizeof(block_iq2_xxs),
+        .is_quantized = true,
+        .to_float = (wsp_ggml_to_float_t) wsp_dewsp_quantize_row_iq2_xxs,
+        .from_float = NULL,
+        .from_float_reference = NULL,
+        .vec_dot = wsp_ggml_vec_dot_iq2_xxs_q8_K,
+        .vec_dot_type = WSP_GGML_TYPE_Q8_K,
+    },
+    [WSP_GGML_TYPE_IQ2_XS] = {
+        .type_name = "iq2_xs",
+        .blck_size = QK_K,
+        .type_size = sizeof(block_iq2_xs),
+        .is_quantized = true,
+        .to_float = (wsp_ggml_to_float_t) wsp_dewsp_quantize_row_iq2_xs,
+        .from_float = NULL,
+        .from_float_reference = NULL,
+        .vec_dot = wsp_ggml_vec_dot_iq2_xs_q8_K,
+        .vec_dot_type = WSP_GGML_TYPE_Q8_K,
+    },
     [WSP_GGML_TYPE_Q8_K] = {
         .type_name = "q8_K",
         .blck_size = QK_K,
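Note on the two new entries: both register a dequantizer (`.to_float`) and a dot-product kernel but leave `.from_float` and `.from_float_reference` as NULL, so at this stage IQ2 tensors can be read and multiplied but not produced through the generic quantization path. A minimal sketch of how that could be probed via the same trait table (the `can_quantize` helper is hypothetical, not part of this diff):

    // Hypothetical helper next to the type_traits table in ggml.c:
    // a type is encodable through the generic path only if it has from_float.
    static bool can_quantize(enum wsp_ggml_type type) {
        return type_traits[type].from_float != NULL; // false for IQ2_XXS / IQ2_XS here
    }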
@@ -1962,19 +1984,19 @@ void wsp_ggml_print_objects(const struct wsp_ggml_context * ctx) {
     WSP_GGML_PRINT("%s: --- end ---\n", __func__);
 }
 
-int64_t wsp_ggml_nelements(const struct wsp_ggml_tensor * tensor) {
+WSP_GGML_CALL int64_t wsp_ggml_nelements(const struct wsp_ggml_tensor * tensor) {
     static_assert(WSP_GGML_MAX_DIMS == 4, "WSP_GGML_MAX_DIMS is not 4 - update this function");
 
     return tensor->ne[0]*tensor->ne[1]*tensor->ne[2]*tensor->ne[3];
 }
 
-int64_t wsp_ggml_nrows(const struct wsp_ggml_tensor * tensor) {
+WSP_GGML_CALL int64_t wsp_ggml_nrows(const struct wsp_ggml_tensor * tensor) {
     static_assert(WSP_GGML_MAX_DIMS == 4, "WSP_GGML_MAX_DIMS is not 4 - update this function");
 
     return tensor->ne[1]*tensor->ne[2]*tensor->ne[3];
 }
 
-size_t wsp_ggml_nbytes(const struct wsp_ggml_tensor * tensor) {
+WSP_GGML_CALL size_t wsp_ggml_nbytes(const struct wsp_ggml_tensor * tensor) {
     size_t nbytes;
     size_t blck_size = wsp_ggml_blck_size(tensor->type);
     if (blck_size == 1) {
@@ -1997,33 +2019,32 @@ size_t wsp_ggml_nbytes_pad(const struct wsp_ggml_tensor * tensor) {
     return WSP_GGML_PAD(wsp_ggml_nbytes(tensor), WSP_GGML_MEM_ALIGN);
 }
 
-size_t wsp_ggml_nbytes_split(const struct wsp_ggml_tensor * tensor, int nrows_split) {
-    static_assert(WSP_GGML_MAX_DIMS == 4, "WSP_GGML_MAX_DIMS is not 4 - update this function");
-
-    return (nrows_split*tensor->ne[0]*wsp_ggml_type_size(tensor->type))/wsp_ggml_blck_size(tensor->type);
-}
-
-int wsp_ggml_blck_size(enum wsp_ggml_type type) {
+WSP_GGML_CALL int wsp_ggml_blck_size(enum wsp_ggml_type type) {
     return type_traits[type].blck_size;
 }
 
-size_t wsp_ggml_type_size(enum wsp_ggml_type type) {
+WSP_GGML_CALL size_t wsp_ggml_type_size(enum wsp_ggml_type type) {
     return type_traits[type].type_size;
 }
 
-float wsp_ggml_type_sizef(enum wsp_ggml_type type) {
-    return ((float)(type_traits[type].type_size))/type_traits[type].blck_size;
+WSP_GGML_CALL size_t wsp_ggml_row_size(enum wsp_ggml_type type, int64_t ne) {
+    assert(ne % wsp_ggml_blck_size(type) == 0);
+    return wsp_ggml_type_size(type)*ne/wsp_ggml_blck_size(type);
+}
+
+double wsp_ggml_type_sizef(enum wsp_ggml_type type) {
+    return ((double)(type_traits[type].type_size))/type_traits[type].blck_size;
 }
 
-const char * wsp_ggml_type_name(enum wsp_ggml_type type) {
+WSP_GGML_CALL const char * wsp_ggml_type_name(enum wsp_ggml_type type) {
     return type_traits[type].type_name;
 }
 
-bool wsp_ggml_is_quantized(enum wsp_ggml_type type) {
+WSP_GGML_CALL bool wsp_ggml_is_quantized(enum wsp_ggml_type type) {
     return type_traits[type].is_quantized;
 }
 
-const char * wsp_ggml_op_name(enum wsp_ggml_op op) {
+WSP_GGML_CALL const char * wsp_ggml_op_name(enum wsp_ggml_op op) {
     return WSP_GGML_OP_NAME[op];
 }
 
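The new `wsp_ggml_row_size` helper centralizes the `type_size*ne/blck_size` arithmetic that later hunks in this diff swap in at each call site, and `wsp_ggml_type_sizef` widens from `float` to `double` so the bytes-per-element ratio stays exact. A worked sketch of the arithmetic (the Q4_0 figures — 32-element blocks stored in 18 bytes — are upstream ggml constants, assumed unchanged here):

    // 4096 elements as Q4_0: 18 bytes per 32-element block
    // -> wsp_ggml_row_size(WSP_GGML_TYPE_Q4_0, 4096) = 18*4096/32 = 2304 bytes
    const size_t row_bytes = wsp_ggml_row_size(WSP_GGML_TYPE_Q4_0, 4096);
    assert(row_bytes == 2304);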
@@ -2035,7 +2056,7 @@ const char * wsp_ggml_unary_op_name(enum wsp_ggml_unary_op op) {
     return WSP_GGML_UNARY_OP_NAME[op];
 }
 
-const char * wsp_ggml_op_desc(const struct wsp_ggml_tensor * t) {
+WSP_GGML_CALL const char * wsp_ggml_op_desc(const struct wsp_ggml_tensor * t) {
     if (t->op == WSP_GGML_OP_UNARY) {
         enum wsp_ggml_unary_op uop = wsp_ggml_get_unary_op(t);
         return wsp_ggml_unary_op_name(uop);
@@ -2045,28 +2066,41 @@ const char * wsp_ggml_op_desc(const struct wsp_ggml_tensor * t) {
     }
 }
 
-size_t wsp_ggml_element_size(const struct wsp_ggml_tensor * tensor) {
+WSP_GGML_CALL size_t wsp_ggml_element_size(const struct wsp_ggml_tensor * tensor) {
     return wsp_ggml_type_size(tensor->type);
 }
 
-static inline bool wsp_ggml_is_scalar(const struct wsp_ggml_tensor * tensor) {
+bool wsp_ggml_is_scalar(const struct wsp_ggml_tensor * tensor) {
     static_assert(WSP_GGML_MAX_DIMS == 4, "WSP_GGML_MAX_DIMS is not 4 - update this function");
 
     return tensor->ne[0] == 1 && tensor->ne[1] == 1 && tensor->ne[2] == 1 && tensor->ne[3] == 1;
 }
 
-static inline bool wsp_ggml_is_vector(const struct wsp_ggml_tensor * tensor) {
+bool wsp_ggml_is_vector(const struct wsp_ggml_tensor * tensor) {
     static_assert(WSP_GGML_MAX_DIMS == 4, "WSP_GGML_MAX_DIMS is not 4 - update this function");
 
     return tensor->ne[1] == 1 && tensor->ne[2] == 1 && tensor->ne[3] == 1;
 }
 
-static inline bool wsp_ggml_is_matrix(const struct wsp_ggml_tensor * tensor) {
+bool wsp_ggml_is_matrix(const struct wsp_ggml_tensor * tensor) {
     static_assert(WSP_GGML_MAX_DIMS == 4, "WSP_GGML_MAX_DIMS is not 4 - update this function");
 
     return tensor->ne[2] == 1 && tensor->ne[3] == 1;
 }
 
+bool wsp_ggml_is_3d(const struct wsp_ggml_tensor * tensor) {
+    return tensor->ne[3] == 1;
+}
+
+int wsp_ggml_n_dims(const struct wsp_ggml_tensor * tensor) {
+    for (int i = WSP_GGML_MAX_DIMS - 1; i >= 1; --i) {
+        if (tensor->ne[i] > 1) {
+            return i + 1;
+        }
+    }
+    return 1;
+}
+
 static inline bool wsp_ggml_can_mul_mat(const struct wsp_ggml_tensor * t0, const struct wsp_ggml_tensor * t1) {
     static_assert(WSP_GGML_MAX_DIMS == 4, "WSP_GGML_MAX_DIMS is not 4 - update this function");
 
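`wsp_ggml_is_3d` and `wsp_ggml_n_dims` exist because the `n_dims` field is removed from `struct wsp_ggml_tensor` later in this diff; the logical rank is now derived from the shape on demand. A small sketch of what the helper returns (the shapes are illustrative):

    // ne = {4096, 32, 1, 1} -> highest i with ne[i] > 1 is i = 1 -> wsp_ggml_n_dims == 2
    // ne = {64, 1, 1, 1}    -> loop finds nothing above ne[0]    -> wsp_ggml_n_dims == 1
    const int rank = wsp_ggml_n_dims(tensor);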
@@ -2099,6 +2133,8 @@ enum wsp_ggml_type wsp_ggml_ftype_to_wsp_ggml_type(enum wsp_ggml_ftype ftype) {
         case WSP_GGML_FTYPE_MOSTLY_Q4_K: wtype = WSP_GGML_TYPE_Q4_K; break;
         case WSP_GGML_FTYPE_MOSTLY_Q5_K: wtype = WSP_GGML_TYPE_Q5_K; break;
         case WSP_GGML_FTYPE_MOSTLY_Q6_K: wtype = WSP_GGML_TYPE_Q6_K; break;
+        case WSP_GGML_FTYPE_MOSTLY_IQ2_XXS: wtype = WSP_GGML_TYPE_IQ2_XXS; break;
+        case WSP_GGML_FTYPE_MOSTLY_IQ2_XS: wtype = WSP_GGML_TYPE_IQ2_XS; break;
         case WSP_GGML_FTYPE_UNKNOWN: wtype = WSP_GGML_TYPE_COUNT; break;
         case WSP_GGML_FTYPE_MOSTLY_Q4_1_SOME_F16: wtype = WSP_GGML_TYPE_COUNT; break;
     }
@@ -2112,11 +2148,11 @@ size_t wsp_ggml_tensor_overhead(void) {
     return WSP_GGML_OBJECT_SIZE + WSP_GGML_TENSOR_SIZE;
 }
 
-bool wsp_ggml_is_transposed(const struct wsp_ggml_tensor * tensor) {
+WSP_GGML_CALL bool wsp_ggml_is_transposed(const struct wsp_ggml_tensor * tensor) {
     return tensor->nb[0] > tensor->nb[1];
 }
 
-bool wsp_ggml_is_contiguous(const struct wsp_ggml_tensor * tensor) {
+WSP_GGML_CALL bool wsp_ggml_is_contiguous(const struct wsp_ggml_tensor * tensor) {
     static_assert(WSP_GGML_MAX_DIMS == 4, "WSP_GGML_MAX_DIMS is not 4 - update this function");
 
     return
@@ -2135,7 +2171,7 @@ static inline bool wsp_ggml_is_contiguous_except_dim_1(const struct wsp_ggml_ten
         tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
 }
 
-bool wsp_ggml_is_permuted(const struct wsp_ggml_tensor * tensor) {
+WSP_GGML_CALL bool wsp_ggml_is_permuted(const struct wsp_ggml_tensor * tensor) {
     static_assert(WSP_GGML_MAX_DIMS == 4, "WSP_GGML_MAX_DIMS is not 4 - update this function");
 
     return tensor->nb[0] > tensor->nb[1] || tensor->nb[1] > tensor->nb[2] || tensor->nb[2] > tensor->nb[3];
@@ -2312,6 +2348,10 @@ struct wsp_ggml_context * wsp_ggml_init(struct wsp_ggml_init_params params) {
 }
 
 void wsp_ggml_free(struct wsp_ggml_context * ctx) {
+    if (ctx == NULL) {
+        return;
+    }
+
     // make this function thread safe
     wsp_ggml_critical_section_start();
 
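`wsp_ggml_free` now tolerates a NULL context, matching the contract of free(3), so teardown paths no longer need their own guard:

    struct wsp_ggml_context * ctx = NULL;
    wsp_ggml_free(ctx); // safe no-op as of this change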
@@ -2371,20 +2411,8 @@ size_t wsp_ggml_get_mem_size(const struct wsp_ggml_context * ctx) {
 size_t wsp_ggml_get_max_tensor_size(const struct wsp_ggml_context * ctx) {
     size_t max_size = 0;
 
-    struct wsp_ggml_object * obj = ctx->objects_begin;
-
-    while (obj != NULL) {
-        if (obj->type == WSP_GGML_OBJECT_TENSOR) {
-            struct wsp_ggml_tensor * tensor = (struct wsp_ggml_tensor *) ((char *) ctx->mem_buffer + obj->offs);
-
-            const size_t size = wsp_ggml_nbytes(tensor);
-
-            if (max_size < size) {
-                max_size = size;
-            }
-        }
-
-        obj = obj->next;
+    for (struct wsp_ggml_tensor * tensor = wsp_ggml_get_first_tensor(ctx); tensor != NULL; tensor = wsp_ggml_get_next_tensor(ctx, tensor)) {
+        max_size = MAX(max_size, wsp_ggml_nbytes(tensor));
     }
 
     return max_size;
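The rewritten loop uses the public tensor iterator pair (made const-correct further down in this diff) instead of walking `ctx->objects_begin` by hand. The same pattern works for any per-tensor scan; a sketch that prints sizes:

    for (struct wsp_ggml_tensor * t = wsp_ggml_get_first_tensor(ctx);
         t != NULL;
         t = wsp_ggml_get_next_tensor(ctx, t)) {
        printf("%-32s %zu bytes\n", t->name, wsp_ggml_nbytes(t));
    }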
@@ -2473,7 +2501,7 @@ static struct wsp_ggml_tensor * wsp_ggml_new_tensor_impl(
         view_src = view_src->view_src;
     }
 
-    size_t data_size = wsp_ggml_type_size(type)*(ne[0]/wsp_ggml_blck_size(type));
+    size_t data_size = wsp_ggml_row_size(type, ne[0]);
     for (int i = 1; i < n_dims; i++) {
         data_size *= ne[i];
     }
@@ -2516,7 +2544,6 @@ static struct wsp_ggml_tensor * wsp_ggml_new_tensor_impl(
         /*.type =*/ type,
         /*.backend =*/ WSP_GGML_BACKEND_CPU,
         /*.buffer =*/ NULL,
-        /*.n_dims =*/ n_dims,
         /*.ne =*/ { 1, 1, 1, 1 },
         /*.nb =*/ { 0, 0, 0, 0 },
         /*.op =*/ WSP_GGML_OP_NONE,
@@ -2623,7 +2650,7 @@ struct wsp_ggml_tensor * wsp_ggml_new_f32(struct wsp_ggml_context * ctx, float v
 }
 
 struct wsp_ggml_tensor * wsp_ggml_dup_tensor(struct wsp_ggml_context * ctx, const struct wsp_ggml_tensor * src) {
-    return wsp_ggml_new_tensor(ctx, src->type, src->n_dims, src->ne);
+    return wsp_ggml_new_tensor(ctx, src->type, WSP_GGML_MAX_DIMS, src->ne);
 }
 
 static void wsp_ggml_set_op_params(struct wsp_ggml_tensor * tensor, const void * params, size_t params_size) {
@@ -3046,7 +3073,7 @@ float * wsp_ggml_get_data_f32(const struct wsp_ggml_tensor * tensor) {
     return (float *)(tensor->data);
 }
 
-enum wsp_ggml_unary_op wsp_ggml_get_unary_op(const struct wsp_ggml_tensor * tensor) {
+WSP_GGML_CALL enum wsp_ggml_unary_op wsp_ggml_get_unary_op(const struct wsp_ggml_tensor * tensor) {
     WSP_GGML_ASSERT(tensor->op == WSP_GGML_OP_UNARY);
     return (enum wsp_ggml_unary_op) wsp_ggml_get_op_params_i32(tensor, 0);
 }
@@ -3072,7 +3099,7 @@ struct wsp_ggml_tensor * wsp_ggml_format_name(struct wsp_ggml_tensor * tensor, c
 struct wsp_ggml_tensor * wsp_ggml_view_tensor(
         struct wsp_ggml_context * ctx,
         struct wsp_ggml_tensor * src) {
-    struct wsp_ggml_tensor * result = wsp_ggml_new_tensor_impl(ctx, src->type, src->n_dims, src->ne, src, 0);
+    struct wsp_ggml_tensor * result = wsp_ggml_new_tensor_impl(ctx, src->type, WSP_GGML_MAX_DIMS, src->ne, src, 0);
     wsp_ggml_format_name(result, "%s (view)", src->name);
 
     for (int i = 0; i < WSP_GGML_MAX_DIMS; i++) {
@@ -3082,7 +3109,7 @@ struct wsp_ggml_tensor * wsp_ggml_view_tensor(
     return result;
 }
 
-struct wsp_ggml_tensor * wsp_ggml_get_first_tensor(struct wsp_ggml_context * ctx) {
+struct wsp_ggml_tensor * wsp_ggml_get_first_tensor(const struct wsp_ggml_context * ctx) {
     struct wsp_ggml_object * obj = ctx->objects_begin;
 
     char * const mem_buffer = ctx->mem_buffer;
@@ -3098,7 +3125,7 @@ struct wsp_ggml_tensor * wsp_ggml_get_first_tensor(struct wsp_ggml_context * ctx
     return NULL;
 }
 
-struct wsp_ggml_tensor * wsp_ggml_get_next_tensor(struct wsp_ggml_context * ctx, struct wsp_ggml_tensor * tensor) {
+struct wsp_ggml_tensor * wsp_ggml_get_next_tensor(const struct wsp_ggml_context * ctx, struct wsp_ggml_tensor * tensor) {
     struct wsp_ggml_object * obj = (struct wsp_ggml_object *) ((char *)tensor - WSP_GGML_OBJECT_SIZE);
     obj = obj->next;
 
@@ -3230,10 +3257,10 @@ static struct wsp_ggml_tensor * wsp_ggml_add_cast_impl(
         is_node = true;
     }
 
-    struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, type, a->n_dims, a->ne);
+    struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, type, WSP_GGML_MAX_DIMS, a->ne);
 
     result->op = WSP_GGML_OP_ADD;
-    result->grad = is_node ? wsp_ggml_new_tensor(ctx, WSP_GGML_TYPE_F32, a->n_dims, a->ne) : NULL;
+    result->grad = is_node ? wsp_ggml_new_tensor(ctx, WSP_GGML_TYPE_F32, WSP_GGML_MAX_DIMS, a->ne) : NULL;
     result->src[0] = a;
     result->src[1] = b;
 
@@ -3602,12 +3629,12 @@ struct wsp_ggml_tensor * wsp_ggml_sum_rows(
         is_node = true;
     }
 
-    int64_t ne[4] = {1,1,1,1};
-    for (int i=1; i<a->n_dims; ++i) {
+    int64_t ne[WSP_GGML_MAX_DIMS] = { 1 };
+    for (int i = 1; i < WSP_GGML_MAX_DIMS; ++i) {
         ne[i] = a->ne[i];
     }
 
-    struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, a->type, a->n_dims, ne);
+    struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, a->type, WSP_GGML_MAX_DIMS, ne);
 
     result->op = WSP_GGML_OP_SUM_ROWS;
     result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL;
@@ -3628,8 +3655,8 @@ struct wsp_ggml_tensor * wsp_ggml_mean(
         is_node = true;
     }
 
-    int64_t ne[WSP_GGML_MAX_DIMS] = { 1, a->ne[1], a->ne[2], a->ne[3] };
-    struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, WSP_GGML_TYPE_F32, a->n_dims, ne);
+    int64_t ne[4] = { 1, a->ne[1], a->ne[2], a->ne[3] };
+    struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, WSP_GGML_TYPE_F32, 4, ne);
 
     result->op = WSP_GGML_OP_MEAN;
     result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL;
@@ -3651,8 +3678,7 @@ struct wsp_ggml_tensor * wsp_ggml_argmax(
         is_node = true;
     }
 
-    int64_t ne[WSP_GGML_MAX_DIMS] = { a->ne[1], 1, 1, 1 };
-    struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, WSP_GGML_TYPE_I32, a->n_dims, ne);
+    struct wsp_ggml_tensor * result = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_I32, a->ne[1]);
 
     result->op = WSP_GGML_OP_ARGMAX;
     result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL;
@@ -3675,7 +3701,7 @@ struct wsp_ggml_tensor * wsp_ggml_repeat(
         is_node = true;
     }
 
-    struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, a->type, b->n_dims, b->ne);
+    struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, a->type, WSP_GGML_MAX_DIMS, b->ne);
 
     result->op = WSP_GGML_OP_REPEAT;
     result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL;
@@ -3702,7 +3728,7 @@ struct wsp_ggml_tensor * wsp_ggml_repeat_back(
         return a;
     }
 
-    struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, a->type, b->n_dims, b->ne);
+    struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, a->type, WSP_GGML_MAX_DIMS, b->ne);
 
     result->op = WSP_GGML_OP_REPEAT_BACK;
     result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL;
@@ -4043,7 +4069,6 @@ static struct wsp_ggml_tensor * wsp_ggml_group_norm_impl(
     result->op = WSP_GGML_OP_GROUP_NORM;
     result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
-    result->src[1] = NULL; // TODO: maybe store epsilon here?
 
     return result;
 }
@@ -4078,7 +4103,7 @@ struct wsp_ggml_tensor * wsp_ggml_mul_mat(
     }
 
     const int64_t ne[4] = { a->ne[1], b->ne[1], b->ne[2], b->ne[3] };
-    struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, WSP_GGML_TYPE_F32, MAX(a->n_dims, b->n_dims), ne);
+    struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, WSP_GGML_TYPE_F32, 4, ne);
 
     result->op = WSP_GGML_OP_MUL_MAT;
     result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL;
@@ -4088,6 +4113,14 @@ struct wsp_ggml_tensor * wsp_ggml_mul_mat(
     return result;
 }
 
+void wsp_ggml_mul_mat_set_prec(
+        struct wsp_ggml_tensor * a,
+        enum wsp_ggml_prec prec) {
+    const int32_t prec_i32 = (int32_t) prec;
+
+    wsp_ggml_set_op_params_i32(a, 0, prec_i32);
+}
+
 // wsp_ggml_mul_mat_id
 
 struct wsp_ggml_tensor * wsp_ggml_mul_mat_id(
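`wsp_ggml_mul_mat_set_prec` stores the requested precision in the node's op_params so a backend can pick a wider accumulator. A usage sketch, assuming the `wsp_ggml_prec` enum mirrors upstream ggml's (`WSP_GGML_PREC_DEFAULT` / `WSP_GGML_PREC_F32`):

    struct wsp_ggml_tensor * kq = wsp_ggml_mul_mat(ctx, k, q);
    wsp_ggml_mul_mat_set_prec(kq, WSP_GGML_PREC_F32); // request f32 accumulation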
@@ -4112,7 +4145,7 @@ struct wsp_ggml_tensor * wsp_ggml_mul_mat_id(
     }
 
     const int64_t ne[4] = { as[0]->ne[1], b->ne[1], b->ne[2], b->ne[3] };
-    struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, WSP_GGML_TYPE_F32, MAX(as[0]->n_dims, b->n_dims), ne);
+    struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, WSP_GGML_TYPE_F32, 4, ne);
 
     wsp_ggml_set_op_params_i32(result, 0, id);
     wsp_ggml_set_op_params_i32(result, 1, n_as);
@@ -4150,7 +4183,7 @@ struct wsp_ggml_tensor * wsp_ggml_out_prod(
 
     // a is broadcastable to b for ne[2] and ne[3] -> use b->ne[2] and b->ne[3]
     const int64_t ne[4] = { a->ne[0], b->ne[0], b->ne[2], b->ne[3] };
-    struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, WSP_GGML_TYPE_F32, MAX(a->n_dims, b->n_dims), ne);
+    struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, WSP_GGML_TYPE_F32, 4, ne);
 
     result->op = WSP_GGML_OP_OUT_PROD;
     result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL;
@@ -4165,23 +4198,23 @@ struct wsp_ggml_tensor * wsp_ggml_out_prod(
 static struct wsp_ggml_tensor * wsp_ggml_scale_impl(
         struct wsp_ggml_context * ctx,
         struct wsp_ggml_tensor * a,
-        struct wsp_ggml_tensor * b,
+        float s,
         bool inplace) {
-    WSP_GGML_ASSERT(wsp_ggml_is_scalar(b));
     WSP_GGML_ASSERT(wsp_ggml_is_padded_1d(a));
 
     bool is_node = false;
 
-    if (a->grad || b->grad) {
+    if (a->grad) {
         is_node = true;
     }
 
     struct wsp_ggml_tensor * result = inplace ? wsp_ggml_view_tensor(ctx, a) : wsp_ggml_dup_tensor(ctx, a);
 
+    wsp_ggml_set_op_params(result, &s, sizeof(s));
+
     result->op = WSP_GGML_OP_SCALE;
     result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
-    result->src[1] = b;
 
     return result;
 }
@@ -4189,15 +4222,15 @@ static struct wsp_ggml_tensor * wsp_ggml_scale_impl(
 struct wsp_ggml_tensor * wsp_ggml_scale(
         struct wsp_ggml_context * ctx,
         struct wsp_ggml_tensor * a,
-        struct wsp_ggml_tensor * b) {
-    return wsp_ggml_scale_impl(ctx, a, b, false);
+        float s) {
+    return wsp_ggml_scale_impl(ctx, a, s, false);
 }
 
 struct wsp_ggml_tensor * wsp_ggml_scale_inplace(
         struct wsp_ggml_context * ctx,
         struct wsp_ggml_tensor * a,
-        struct wsp_ggml_tensor * b) {
-    return wsp_ggml_scale_impl(ctx, a, b, true);
+        float s) {
+    return wsp_ggml_scale_impl(ctx, a, s, true);
 }
 
 // wsp_ggml_set
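Callers of `wsp_ggml_scale`/`wsp_ggml_scale_inplace` must migrate from a 1-element tensor to a plain float; the factor now lives in the node's op_params rather than in `src[1]`. A before/after sketch (the 1/sqrt(64) factor is illustrative):

    // before: wsp_ggml_scale(ctx, a, wsp_ggml_new_f32(ctx, 1.0f/sqrtf(64.0f)));
    // after:
    struct wsp_ggml_tensor * scaled = wsp_ggml_scale(ctx, a, 1.0f/sqrtf(64.0f));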
@@ -4294,13 +4327,13 @@ struct wsp_ggml_tensor * wsp_ggml_set_2d_inplace(
 static struct wsp_ggml_tensor * wsp_ggml_cpy_impl(
         struct wsp_ggml_context * ctx,
         struct wsp_ggml_tensor * a,
-        struct wsp_ggml_tensor * b,
-        bool inplace) {
+        struct wsp_ggml_tensor * b) {
     WSP_GGML_ASSERT(wsp_ggml_nelements(a) == wsp_ggml_nelements(b));
 
     bool is_node = false;
 
-    if (!inplace && (a->grad || b->grad)) {
+    if (a->grad || b->grad) {
+        // inplace is false and either one have a grad
         is_node = true;
     }
 
@@ -4324,29 +4357,38 @@ struct wsp_ggml_tensor * wsp_ggml_cpy(
         struct wsp_ggml_context * ctx,
         struct wsp_ggml_tensor * a,
         struct wsp_ggml_tensor * b) {
-    return wsp_ggml_cpy_impl(ctx, a, b, false);
+    return wsp_ggml_cpy_impl(ctx, a, b);
 }
 
-struct wsp_ggml_tensor * wsp_ggml_cpy_inplace(
+struct wsp_ggml_tensor * wsp_ggml_cast(
         struct wsp_ggml_context * ctx,
-        struct wsp_ggml_tensor * a,
-        struct wsp_ggml_tensor * b) {
-    return wsp_ggml_cpy_impl(ctx, a, b, true);
+        struct wsp_ggml_tensor * a,
+        enum wsp_ggml_type type) {
+    bool is_node = false;
+
+    struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, type, WSP_GGML_MAX_DIMS, a->ne);
+    wsp_ggml_format_name(result, "%s (copy)", a->name);
+
+    result->op = WSP_GGML_OP_CPY;
+    result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL;
+    result->src[0] = a;
+    result->src[1] = result;
+
+    return result;
 }
 
 // wsp_ggml_cont
 
 static struct wsp_ggml_tensor * wsp_ggml_cont_impl(
         struct wsp_ggml_context * ctx,
-        struct wsp_ggml_tensor * a,
-        bool inplace) {
+        struct wsp_ggml_tensor * a) {
     bool is_node = false;
 
-    if (!inplace && a->grad) {
+    if (a->grad) {
         is_node = true;
     }
 
-    struct wsp_ggml_tensor * result = inplace ? wsp_ggml_view_tensor(ctx, a) : wsp_ggml_dup_tensor(ctx, a);
+    struct wsp_ggml_tensor * result = wsp_ggml_dup_tensor(ctx, a);
     wsp_ggml_format_name(result, "%s (cont)", a->name);
 
     result->op = WSP_GGML_OP_CONT;
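`wsp_ggml_cast` takes over the main use of the removed `wsp_ggml_cpy_inplace` — copying into a destination of another type — except it allocates the destination itself (note `src[1] = result`: the CPY node writes into its own buffer). A usage sketch:

    // materialize an f16 copy of an f32 tensor
    struct wsp_ggml_tensor * a_f16 = wsp_ggml_cast(ctx, a_f32, WSP_GGML_TYPE_F16);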
@@ -4359,13 +4401,7 @@ static struct wsp_ggml_tensor * wsp_ggml_cont_impl(
 struct wsp_ggml_tensor * wsp_ggml_cont(
         struct wsp_ggml_context * ctx,
         struct wsp_ggml_tensor * a) {
-    return wsp_ggml_cont_impl(ctx, a, false);
-}
-
-struct wsp_ggml_tensor * wsp_ggml_cont_inplace(
-        struct wsp_ggml_context * ctx,
-        struct wsp_ggml_tensor * a) {
-    return wsp_ggml_cont_impl(ctx, a, true);
+    return wsp_ggml_cont_impl(ctx, a);
 }
 
 // make contiguous, with new shape
@@ -4435,7 +4471,7 @@ struct wsp_ggml_tensor * wsp_ggml_reshape(
         //WSP_GGML_ASSERT(false);
     }
 
-    struct wsp_ggml_tensor * result = wsp_ggml_new_tensor_impl(ctx, a->type, b->n_dims, b->ne, a, 0);
+    struct wsp_ggml_tensor * result = wsp_ggml_new_tensor_impl(ctx, a->type, WSP_GGML_MAX_DIMS, b->ne, a, 0);
     wsp_ggml_format_name(result, "%s (reshaped)", a->name);
 
     result->op = WSP_GGML_OP_RESHAPE;
@@ -4761,8 +4797,11 @@ struct wsp_ggml_tensor * wsp_ggml_get_rows(
     }
 
     // TODO: implement non F32 return
-    //struct wsp_ggml_tensor * result = wsp_ggml_new_tensor_2d(ctx, a->type, a->ne[0], b->ne[0]);
-    struct wsp_ggml_tensor * result = wsp_ggml_new_tensor_4d(ctx, WSP_GGML_TYPE_F32, a->ne[0], b->ne[0], b->ne[1], b->ne[2]);
+    enum wsp_ggml_type type = WSP_GGML_TYPE_F32;
+    if (a->type == WSP_GGML_TYPE_I32) {
+        type = a->type;
+    }
+    struct wsp_ggml_tensor * result = wsp_ggml_new_tensor_4d(ctx, type, a->ne[0], b->ne[0], b->ne[1], b->ne[2]);
 
     result->op = WSP_GGML_OP_GET_ROWS;
     result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL;
@@ -4813,7 +4852,7 @@ struct wsp_ggml_tensor * wsp_ggml_diag(
     }
 
     const int64_t ne[4] = { a->ne[0], a->ne[0], a->ne[2], a->ne[3] };
-    struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, a->type, MAX(a->n_dims, 2), ne);
+    struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, a->type, 4, ne);
 
     result->op = WSP_GGML_OP_DIAG;
     result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL;
@@ -5460,7 +5499,7 @@ struct wsp_ggml_tensor * wsp_ggml_pool_1d(
         is_node = true;
     }
 
-    const int64_t ne[3] = {
+    const int64_t ne[2] = {
         wsp_ggml_calc_pool_output_size(a->ne[0], k0, s0, p0),
         a->ne[1],
     };
@@ -5535,7 +5574,6 @@ static struct wsp_ggml_tensor * wsp_ggml_upscale_impl(
     result->op_params[0] = scale_factor;
     result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
-    result->src[1] = NULL;
 
     return result;
 }
@@ -5579,7 +5617,7 @@ struct wsp_ggml_tensor * wsp_ggml_argsort(
         enum wsp_ggml_sort_order order) {
     bool is_node = false;
 
-    struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, WSP_GGML_TYPE_I32, a->n_dims, a->ne);
+    struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, WSP_GGML_TYPE_I32, WSP_GGML_MAX_DIMS, a->ne);
 
     wsp_ggml_set_op_params_i32(result, 0, (int32_t) order);
 
@@ -5626,7 +5664,7 @@ struct wsp_ggml_tensor * wsp_ggml_flash_attn(
     }
 
     //struct wsp_ggml_tensor * result = wsp_ggml_dup_tensor(ctx, q);
-    struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, WSP_GGML_TYPE_F32, q->n_dims, q->ne);
+    struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, WSP_GGML_TYPE_F32, WSP_GGML_MAX_DIMS, q->ne);
 
     int32_t t = masked ? 1 : 0;
     wsp_ggml_set_op_params(result, &t, sizeof(t));
@@ -5659,7 +5697,7 @@ struct wsp_ggml_tensor * wsp_ggml_flash_ff(
     }
 
     //struct wsp_ggml_tensor * result = wsp_ggml_dup_tensor(ctx, a);
-    struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, WSP_GGML_TYPE_F32, a->n_dims, a->ne);
+    struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, WSP_GGML_TYPE_F32, WSP_GGML_MAX_DIMS, a->ne);
 
     result->op = WSP_GGML_OP_FLASH_FF;
     result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL;
@@ -5775,7 +5813,6 @@ struct wsp_ggml_tensor * wsp_ggml_win_part(
     const int np = npx*npy;
 
     const int64_t ne[4] = { a->ne[0], w, w, np, };
-
     struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, WSP_GGML_TYPE_F32, 4, ne);
 
     int32_t params[] = { npx, npy, w };
@@ -5841,7 +5878,6 @@ struct wsp_ggml_tensor * wsp_ggml_get_rel_pos(
     result->op = WSP_GGML_OP_GET_REL_POS;
     result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
-    result->src[1] = NULL;
 
     return result;
 }
@@ -6936,14 +6972,165 @@ static void wsp_ggml_compute_forward_dup_f32(
     }
 }
 
-static void wsp_ggml_compute_forward_dup(
+// A simplified version of wsp_ggml_compute_forward_dup that doesn't do float upcasting, and just plain old memcpy.
+static void wsp_ggml_compute_forward_dup_bytes(
         const struct wsp_ggml_compute_params * params,
         const struct wsp_ggml_tensor * src0,
         struct wsp_ggml_tensor * dst) {
-    if (wsp_ggml_is_contiguous(src0) && wsp_ggml_is_contiguous(dst) && src0->type == dst->type) {
+    WSP_GGML_ASSERT(wsp_ggml_nelements(dst) == wsp_ggml_nelements(src0));
+    WSP_GGML_ASSERT(src0->type == dst->type);
+
+    if (params->type == WSP_GGML_TASK_INIT || params->type == WSP_GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    if (wsp_ggml_is_contiguous(src0) && wsp_ggml_is_contiguous(dst)) {
         wsp_ggml_compute_forward_dup_same_cont(params, src0, dst);
         return;
     }
+
+    WSP_GGML_TENSOR_UNARY_OP_LOCALS;
+    const size_t type_size = wsp_ggml_type_size(src0->type);
+    const int ith = params->ith; // thread index
+    const int nth = params->nth; // number of threads
+
+
+    // parallelize by rows
+    const int nr = ne01;
+    // number of rows per thread
+    const int dr = (nr + nth - 1) / nth;
+    // row range for this thread
+    const int ir0 = dr * ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    if (src0->type == dst->type &&
+        ne00 == ne0 &&
+        nb00 == type_size && nb0 == type_size) {
+        // copy by rows
+        const size_t rs = ne00 * type_size;
+        for (int64_t i03 = 0; i03 < ne03; i03++) {
+            for (int64_t i02 = 0; i02 < ne02; i02++) {
+                for (int64_t i01 = ir0; i01 < ir1; i01++) {
+                    memcpy(
+                        ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3),
+                        ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
+                        rs);
+                }
+            }
+        }
+        return;
+    }
+
+    if (wsp_ggml_is_contiguous(dst)) {
+        size_t id = 0;
+        char * dst_ptr = (char *) dst->data;
+        const size_t rs = ne00 * type_size;
+
+        if (nb00 == type_size) {
+            // src0 is contigous on first dimension, copy by rows
+            for (int64_t i03 = 0; i03 < ne03; i03++) {
+                for (int64_t i02 = 0; i02 < ne02; i02++) {
+                    id += rs * ir0;
+                    for (int64_t i01 = ir0; i01 < ir1; i01++) {
+                        const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
+                        memcpy(dst_ptr + id, src0_ptr, rs);
+                        id += rs;
+                    }
+                    id += rs * (ne01 - ir1);
+                }
+            }
+        } else {
+            //printf("%s: this is not optimal - fix me\n", __func__);
+
+            for (int64_t i03 = 0; i03 < ne03; i03++) {
+                for (int64_t i02 = 0; i02 < ne02; i02++) {
+                    id += rs * ir0;
+                    for (int64_t i01 = ir0; i01 < ir1; i01++) {
+                        for (int64_t i00 = 0; i00 < ne00; i00++) {
+                            const char * src0_ptr = (char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03;
+                            memcpy(dst_ptr + id, src0_ptr, type_size);
+
+                            id += type_size;
+                        }
+                    }
+                    id += rs * (ne01 - ir1);
+                }
+            }
+        }
+
+        return;
+    }
+
+    // dst counters
+
+    int64_t i10 = 0;
+    int64_t i11 = 0;
+    int64_t i12 = 0;
+    int64_t i13 = 0;
+
+    for (int64_t i03 = 0; i03 < ne03; i03++) {
+        for (int64_t i02 = 0; i02 < ne02; i02++) {
+            i10 += ne00 * ir0;
+            while (i10 >= ne0) {
+                i10 -= ne0;
+                if (++i11 == ne1) {
+                    i11 = 0;
+                    if (++i12 == ne2) {
+                        i12 = 0;
+                        if (++i13 == ne3) {
+                            i13 = 0;
+                        }
+                    }
+                }
+            }
+            for (int64_t i01 = ir0; i01 < ir1; i01++) {
+                for (int64_t i00 = 0; i00 < ne00; i00++) {
+                    const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+                    char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
+
+                    memcpy(dst_ptr, src0_ptr, type_size);
+
+                    if (++i10 == ne0) {
+                        i10 = 0;
+                        if (++i11 == ne1) {
+                            i11 = 0;
+                            if (++i12 == ne2) {
+                                i12 = 0;
+                                if (++i13 == ne3) {
+                                    i13 = 0;
+                                }
+                            }
+                        }
+                    }
+                }
+            }
+            i10 += ne00 * (ne01 - ir1);
+            while (i10 >= ne0) {
+                i10 -= ne0;
+                if (++i11 == ne1) {
+                    i11 = 0;
+                    if (++i12 == ne2) {
+                        i12 = 0;
+                        if (++i13 == ne3) {
+                            i13 = 0;
+                        }
+                    }
+                }
+            }
+        }
+    }
+}
+
+static void wsp_ggml_compute_forward_dup(
+        const struct wsp_ggml_compute_params * params,
+        const struct wsp_ggml_tensor * src0,
+        struct wsp_ggml_tensor * dst) {
+    if (src0->type == dst->type) {
+        wsp_ggml_compute_forward_dup_bytes(params, src0, dst);
+        return;
+    }
+
     switch (src0->type) {
         case WSP_GGML_TYPE_F16:
             {
@@ -7280,6 +7467,8 @@ static void wsp_ggml_compute_forward_add(
         case WSP_GGML_TYPE_Q4_K:
         case WSP_GGML_TYPE_Q5_K:
         case WSP_GGML_TYPE_Q6_K:
+        case WSP_GGML_TYPE_IQ2_XXS:
+        case WSP_GGML_TYPE_IQ2_XS:
             {
                 wsp_ggml_compute_forward_add_q_f32(params, src0, src1, dst);
             } break;
@@ -7544,6 +7733,8 @@ static void wsp_ggml_compute_forward_add1(
         case WSP_GGML_TYPE_Q4_K:
         case WSP_GGML_TYPE_Q5_K:
        case WSP_GGML_TYPE_Q6_K:
+        case WSP_GGML_TYPE_IQ2_XXS:
+        case WSP_GGML_TYPE_IQ2_XS:
             {
                 wsp_ggml_compute_forward_add1_q_f32(params, src0, src1, dst);
             } break;
@@ -7658,6 +7849,8 @@ static void wsp_ggml_compute_forward_acc(
         case WSP_GGML_TYPE_Q4_K:
         case WSP_GGML_TYPE_Q5_K:
         case WSP_GGML_TYPE_Q6_K:
+        case WSP_GGML_TYPE_IQ2_XXS:
+        case WSP_GGML_TYPE_IQ2_XS:
         default:
             {
                 WSP_GGML_ASSERT(false);
@@ -7759,10 +7952,10 @@ static void wsp_ggml_compute_forward_mul_f32(
     const int ith = params->ith;
     const int nth = params->nth;
 
-    // TODO: OpenCL kernel support broadcast
 #ifdef WSP_GGML_USE_CLBLAST
     if (src1->backend == WSP_GGML_BACKEND_GPU) {
-        WSP_GGML_ASSERT(wsp_ggml_are_same_shape(src0, src1));
+        // TODO: OpenCL kernel support full broadcast
+        WSP_GGML_ASSERT(wsp_ggml_can_repeat_rows(src1, src0));
         if (ith == 0) {
             wsp_ggml_cl_mul(src0, src1, dst);
         }
@@ -8402,10 +8595,12 @@ static void wsp_ggml_compute_forward_repeat(
         struct wsp_ggml_tensor * dst) {
     switch (src0->type) {
         case WSP_GGML_TYPE_F16:
+        case WSP_GGML_TYPE_I16:
            {
                wsp_ggml_compute_forward_repeat_f16(params, src0, dst);
            } break;
        case WSP_GGML_TYPE_F32:
+        case WSP_GGML_TYPE_I32:
            {
                wsp_ggml_compute_forward_repeat_f32(params, src0, dst);
            } break;
@@ -8548,6 +8743,7 @@ static void wsp_ggml_compute_forward_concat(
     struct wsp_ggml_tensor* dst) {
     switch (src0->type) {
         case WSP_GGML_TYPE_F32:
+        case WSP_GGML_TYPE_I32:
             {
                 wsp_ggml_compute_forward_concat_f32(params, src0, src1, dst);
             } break;
@@ -9159,6 +9355,8 @@ static void wsp_ggml_compute_forward_norm_f32(
     float eps;
     memcpy(&eps, dst->op_params, sizeof(float));
 
+    WSP_GGML_ASSERT(eps > 0.0f);
+
     // TODO: optimize
     for (int64_t i03 = 0; i03 < ne03; i03++) {
         for (int64_t i02 = 0; i02 < ne02; i02++) {
@@ -9228,6 +9426,8 @@ static void wsp_ggml_compute_forward_rms_norm_f32(
     float eps;
     memcpy(&eps, dst->op_params, sizeof(float));
 
+    WSP_GGML_ASSERT(eps > 0.0f);
+
     // TODO: optimize
     for (int64_t i03 = 0; i03 < ne03; i03++) {
         for (int64_t i02 = 0; i02 < ne02; i02++) {
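Both normalization kernels now assert a strictly positive epsilon, so an uninitialized or zero eps fails loudly at compute time instead of silently producing NaNs. A sketch of a conforming call, assuming the graph API mirrors upstream ggml's `ggml_rms_norm(ctx, a, eps)` signature (the 1e-5f value is a conventional choice, not mandated by this diff):

    struct wsp_ggml_tensor * cur = wsp_ggml_rms_norm(ctx, x, 1e-5f); // eps must satisfy eps > 0.0f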
@@ -9541,10 +9741,10 @@ static void wsp_ggml_compute_forward_group_norm(
 #if defined(WSP_GGML_USE_ACCELERATE) || defined(WSP_GGML_USE_OPENBLAS)
 // helper function to determine if it is better to use BLAS or not
 // for large matrices, BLAS is faster
-static bool wsp_ggml_compute_forward_mul_mat_use_blas(
-        const struct wsp_ggml_tensor * src0,
-        const struct wsp_ggml_tensor * src1,
-        struct wsp_ggml_tensor * dst) {
+static bool wsp_ggml_compute_forward_mul_mat_use_blas(struct wsp_ggml_tensor * dst) {
+    const struct wsp_ggml_tensor * src0 = dst->src[0];
+    const struct wsp_ggml_tensor * src1 = dst->src[1];
+
     //const int64_t ne00 = src0->ne[0];
     //const int64_t ne01 = src0->ne[1];
 
@@ -9571,16 +9771,11 @@ static bool wsp_ggml_compute_forward_mul_mat_use_blas(
 }
 #endif
 
-// off1 = offset in i11 and i1
-// cne1 = ne11 and ne1
-// in a normal matrix multiplication, off1 = 0 and cne1 = ne1
-// during WSP_GGML_TASK_INIT, the full src1 is converted regardless of off1 and cne1
 static void wsp_ggml_compute_forward_mul_mat(
         const struct wsp_ggml_compute_params * params,
         const struct wsp_ggml_tensor * src0,
         const struct wsp_ggml_tensor * src1,
-        struct wsp_ggml_tensor * dst,
-        int64_t off1, int64_t cne1) {
+        struct wsp_ggml_tensor * dst) {
     int64_t t0 = wsp_ggml_perf_time_us();
     UNUSED(t0);
 
@@ -9629,7 +9824,7 @@ static void wsp_ggml_compute_forward_mul_mat(
 #endif
 
 #if defined(WSP_GGML_USE_ACCELERATE) || defined(WSP_GGML_USE_OPENBLAS)
-    if (wsp_ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
+    if (wsp_ggml_compute_forward_mul_mat_use_blas(dst)) {
         if (params->ith != 0) {
             return;
         }
@@ -9648,9 +9843,9 @@ static void wsp_ggml_compute_forward_mul_mat(
                 const int64_t i03 = i13/r3;
                 const int64_t i02 = i12/r2;
 
-                const void * x = (char *) src0->data + i02*nb02 + i03*nb03;
-                const float * y = (float *) ((char *) src1->data + off1*nb11 + i12*nb12 + i13*nb13);
-                float * d = (float *) ((char *) dst->data + off1*nb1 + i12*nb2 + i13*nb3);
+                const void * x = (char *) src0->data + i02*nb02 + i03*nb03;
+                const float * y = (float *) ((char *) src1->data + i12*nb12 + i13*nb13);
+                float * d = (float *) ((char *) dst->data + i12*nb2 + i13*nb3);
 
                 if (type != WSP_GGML_TYPE_F32) {
                     float * const wdata = params->wdata;
@@ -9667,7 +9862,7 @@ static void wsp_ggml_compute_forward_mul_mat(
                 }
 
                 cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
-                        cne1, ne01, ne10,
+                        ne1, ne01, ne10,
                         1.0f, y, ne10,
                         x, ne00,
                         0.0f, d, ne01);
@@ -9683,10 +9878,10 @@ static void wsp_ggml_compute_forward_mul_mat(
     if (params->type == WSP_GGML_TASK_INIT) {
         if (src1->type != vec_dot_type) {
             char * wdata = params->wdata;
-            const size_t row_size = ne10*wsp_ggml_type_size(vec_dot_type)/wsp_ggml_blck_size(vec_dot_type);
+            const size_t row_size = wsp_ggml_row_size(vec_dot_type, ne10);
 
             assert(params->wsize >= ne11*ne12*ne13*row_size);
-            assert(src1->type == WSP_GGML_TYPE_F32);
+            WSP_GGML_ASSERT(src1->type == WSP_GGML_TYPE_F32);
 
             for (int64_t i13 = 0; i13 < ne13; ++i13) {
                 for (int64_t i12 = 0; i12 < ne12; ++i12) {
@@ -9706,10 +9901,10 @@ static void wsp_ggml_compute_forward_mul_mat(
     }
 
     const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
-    const size_t row_size = ne10*wsp_ggml_type_size(vec_dot_type)/wsp_ggml_blck_size(vec_dot_type);
+    const size_t row_size = wsp_ggml_row_size(vec_dot_type, ne10);
 
-    const int64_t nr0 = ne01; // src0 rows
-    const int64_t nr1 = cne1*ne12*ne13; // src1 rows
+    const int64_t nr0 = ne01; // src0 rows
+    const int64_t nr1 = ne1*ne12*ne13; // src1 rows
 
     //printf("nr0 = %lld, nr1 = %lld\n", nr0, nr1);
 
@@ -9751,9 +9946,9 @@ static void wsp_ggml_compute_forward_mul_mat(
     for (int64_t iir1 = ir110; iir1 < ir111; iir1 += blck_1) {
         for (int64_t iir0 = ir010; iir0 < ir011; iir0 += blck_0) {
             for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir111; ++ir1) {
-                const int64_t i13 = (ir1/(ne12*cne1));
-                const int64_t i12 = (ir1 - i13*ne12*cne1)/cne1;
-                const int64_t i11 = (ir1 - i13*ne12*cne1 - i12*cne1) + off1;
+                const int64_t i13 = (ir1/(ne12*ne1));
+                const int64_t i12 = (ir1 - i13*ne12*ne1)/ne1;
+                const int64_t i11 = (ir1 - i13*ne12*ne1 - i12*ne1);
 
                 // broadcast src0 into src1
                 const int64_t i03 = i13/r3;
@@ -9793,28 +9988,191 @@ static void wsp_ggml_compute_forward_mul_mat(
 
 static void wsp_ggml_compute_forward_mul_mat_id(
         const struct wsp_ggml_compute_params * params,
-        const struct wsp_ggml_tensor * src0,
+        const struct wsp_ggml_tensor * ids,
         const struct wsp_ggml_tensor * src1,
         struct wsp_ggml_tensor * dst) {
 
-    if (params->type == WSP_GGML_TASK_INIT || params->type == WSP_GGML_TASK_FINALIZE) {
-        // during WSP_GGML_TASK_INIT the entire src1 is converted to vec_dot_type
-        wsp_ggml_compute_forward_mul_mat(params, dst->src[2], src1, dst, 0, dst->ne[1]);
-        return;
-    }
+    const struct wsp_ggml_tensor * src0 = dst->src[2]; // only for WSP_GGML_TENSOR_BINARY_OP_LOCALS
+
+    WSP_GGML_TENSOR_BINARY_OP_LOCALS
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const enum wsp_ggml_type type = src0->type;
+
+    const bool src1_cont = wsp_ggml_is_contiguous(src1);
+
+    wsp_ggml_vec_dot_t const vec_dot = type_traits[type].vec_dot;
+    enum wsp_ggml_type const vec_dot_type = type_traits[type].vec_dot_type;
+    wsp_ggml_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float;
+
+    WSP_GGML_ASSERT(ne0 == ne01);
+    WSP_GGML_ASSERT(ne1 == ne11);
+    WSP_GGML_ASSERT(ne2 == ne12);
+    WSP_GGML_ASSERT(ne3 == ne13);
 
-    const struct wsp_ggml_tensor * ids = src0;
+    // we don't support permuted src0 or src1
+    WSP_GGML_ASSERT(nb00 == wsp_ggml_type_size(type));
+    WSP_GGML_ASSERT(nb10 == wsp_ggml_type_size(src1->type));
+
+    // dst cannot be transposed or permuted
+    WSP_GGML_ASSERT(nb0 == sizeof(float));
+    WSP_GGML_ASSERT(nb0 <= nb1);
+    WSP_GGML_ASSERT(nb1 <= nb2);
+    WSP_GGML_ASSERT(nb2 <= nb3);
+
+    // broadcast factors
+    const int64_t r2 = ne12/ne02;
+    const int64_t r3 = ne13/ne03;
+
+    // row groups
     const int id = wsp_ggml_get_op_params_i32(dst, 0);
     const int n_as = wsp_ggml_get_op_params_i32(dst, 1);
 
-    for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
-        const int32_t row_id = *(const int32_t *) ((const char *) ids->data + i01*ids->nb[1] + id*ids->nb[0]);
+    char * wdata_src1_end = (src1->type == vec_dot_type) ?
+            (char *) params->wdata :
+            (char *) params->wdata + WSP_GGML_PAD(wsp_ggml_row_size(vec_dot_type, wsp_ggml_nelements(src1)), sizeof(int64_t));
+
+    int64_t * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as]
+    int64_t * matrix_rows = matrix_row_counts + n_as; // [n_as][ne11]
+
+    #define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id)*ne11 + (i1)]
+
+    if (params->type == WSP_GGML_TASK_INIT) {
+        char * wdata = params->wdata;
+        if (src1->type != vec_dot_type) {
+            const size_t row_size = wsp_ggml_row_size(vec_dot_type, ne10);
+
+            assert(params->wsize >= ne11*ne12*ne13*row_size);
+            assert(src1->type == WSP_GGML_TYPE_F32);
+
+            for (int64_t i13 = 0; i13 < ne13; ++i13) {
+                for (int64_t i12 = 0; i12 < ne12; ++i12) {
+                    for (int64_t i11 = 0; i11 < ne11; ++i11) {
+                        from_float_to_vec_dot((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10);
+                        wdata += row_size;
+                    }
+                }
+            }
+        }
+
+        // initialize matrix_row_counts
+        WSP_GGML_ASSERT(wdata == wdata_src1_end);
+        memset(matrix_row_counts, 0, n_as*sizeof(int64_t));
+
+        // group rows by src0 matrix
+        for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
+            const int32_t row_id = *(const int32_t *) ((const char *) ids->data + i01*ids->nb[1] + id*ids->nb[0]);
 
-        WSP_GGML_ASSERT(row_id >= 0 && row_id < n_as);
+            WSP_GGML_ASSERT(row_id >= 0 && row_id < n_as);
+            MMID_MATRIX_ROW(row_id, matrix_row_counts[row_id]) = i01;
+            matrix_row_counts[row_id] += 1;
+        }
+
+        return;
+    }
 
-        const struct wsp_ggml_tensor * src0_row = dst->src[row_id + 2];
-        wsp_ggml_compute_forward_mul_mat(params, src0_row, src1, dst, i01, 1);
+    if (params->type == WSP_GGML_TASK_FINALIZE) {
+        return;
     }
+
+    // compute each matrix multiplication in sequence
+    for (int cur_a = 0; cur_a < n_as; ++cur_a) {
+        const int64_t cne1 = matrix_row_counts[cur_a];
+
+        if (cne1 == 0) {
+            continue;
+        }
+
+        const struct wsp_ggml_tensor * src0_cur = dst->src[cur_a + 2];
+
+        const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
+        const size_t row_size = wsp_ggml_row_size(vec_dot_type, ne10);
+
+        const int64_t nr0 = ne01; // src0 rows
+        const int64_t nr1 = cne1*ne12*ne13; // src1 rows
+
+        //printf("nr0 = %lld, nr1 = %lld\n", nr0, nr1);
+
+        // distribute the thread work across the inner or outer loop based on which one is larger
+
+        const int64_t nth0 = nr0 > nr1 ? nth : 1; // parallelize by src0 rows
+        const int64_t nth1 = nr0 > nr1 ? 1 : nth; // parallelize by src1 rows
+
+        const int64_t ith0 = ith % nth0;
+        const int64_t ith1 = ith / nth0;
+
+        const int64_t dr0 = (nr0 + nth0 - 1)/nth0;
+        const int64_t dr1 = (nr1 + nth1 - 1)/nth1;
+
+        const int64_t ir010 = dr0*ith0;
+        const int64_t ir011 = MIN(ir010 + dr0, nr0);
+
+        const int64_t ir110 = dr1*ith1;
+        const int64_t ir111 = MIN(ir110 + dr1, nr1);
+
+        //printf("ir010 = %6lld, ir011 = %6lld, ir110 = %6lld, ir111 = %6lld\n", ir010, ir011, ir110, ir111);
+
+        // threads with no work simply yield (not sure if it helps)
+        if (ir010 >= ir011 || ir110 >= ir111) {
+            sched_yield();
+            continue;
+        }
+
+        assert(ne12 % ne02 == 0);
+        assert(ne13 % ne03 == 0);
+
+        // block-tiling attempt
+        const int64_t blck_0 = 16;
+        const int64_t blck_1 = 16;
+
+        // attempt to reduce false-sharing (does not seem to make a difference)
+        float tmp[16];
+
+        for (int64_t iir1 = ir110; iir1 < ir111; iir1 += blck_1) {
+            for (int64_t iir0 = ir010; iir0 < ir011; iir0 += blck_0) {
+                for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir111; ++ir1) {
+                    const int64_t i13 = (ir1/(ne12*cne1)); // Note: currently, src1 is always a matrix
+                    const int64_t i12 = (ir1 - i13*ne12*cne1)/cne1;
+                    const int64_t _i11 = (ir1 - i13*ne12*cne1 - i12*cne1);
+                    const int64_t i11 = MMID_MATRIX_ROW(cur_a, _i11);
+
+                    // broadcast src0 into src1
+                    const int64_t i03 = i13/r3;
+                    const int64_t i02 = i12/r2;
+
+                    const int64_t i1 = i11;
+                    const int64_t i2 = i12;
+                    const int64_t i3 = i13;
+
+                    const char * src0_row = (const char *) src0_cur->data + (0 + i02*nb02 + i03*nb03);
+
+                    // desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides
+                    //       if it is, then we have either copied the data to params->wdata and made it contiguous or we are using
+                    //       the original src1 data pointer, so we should index using the indices directly
+                    // TODO: this is a bit of a hack, we should probably have a better way to handle this
+                    const char * src1_col = (const char *) wdata +
+                        (src1_cont || src1->type != vec_dot_type
+                        ? (i11 + i12*ne11 + i13*ne12*ne11)*row_size
+                        : (i11*nb11 + i12*nb12 + i13*nb13));
+
+                    float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3));
+
+                    //for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) {
+                    //    vec_dot(ne00, &dst_col[ir0], src0_row + ir0*nb01, src1_col);
+                    //}
+
+                    for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) {
+                        vec_dot(ne00, &tmp[ir0 - iir0], src0_row + ir0*nb01, src1_col);
+                    }
+                    memcpy(&dst_col[iir0], tmp, (MIN(iir0 + blck_0, ir011) - iir0)*sizeof(float));
+                }
+            }
+        }
+    }
+
+    #undef MMID_MATRIX_ROW
 }
 
 // wsp_ggml_compute_forward_out_prod
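The reworked `wsp_ggml_compute_forward_mul_mat_id` no longer issues one offset `mul_mat` call per routed row (the `off1`/`cne1` mechanism deleted above). Instead, the INIT pass buckets src1 rows by the expert chosen in `ids`, and the compute pass runs one batched multiplication per expert. The bookkeeping, in outline (illustrative comments only — the macro and arrays are the ones defined in the hunk above):

    // after INIT, for each expert e in [0, n_as):
    //   matrix_row_counts[e]  = number of src1 rows routed to expert e
    //   MMID_MATRIX_ROW(e, k) = src1 row index of the k-th row in that bucket
    // the compute pass then multiplies dst->src[e + 2] against each bucket,
    // scattering results back via i11 = MMID_MATRIX_ROW(cur_a, _i11)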
@@ -10134,6 +10492,8 @@ static void wsp_ggml_compute_forward_out_prod(
         case WSP_GGML_TYPE_Q4_K:
         case WSP_GGML_TYPE_Q5_K:
         case WSP_GGML_TYPE_Q6_K:
+        case WSP_GGML_TYPE_IQ2_XXS:
+        case WSP_GGML_TYPE_IQ2_XS:
             {
                 wsp_ggml_compute_forward_out_prod_q_f32(params, src0, src1, dst);
             } break;
@@ -10158,19 +10518,18 @@ static void wsp_ggml_compute_forward_out_prod(
 static void wsp_ggml_compute_forward_scale_f32(
         const struct wsp_ggml_compute_params * params,
         const struct wsp_ggml_tensor * src0,
-        const struct wsp_ggml_tensor * src1,
         struct wsp_ggml_tensor * dst) {
     WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src0));
     WSP_GGML_ASSERT(wsp_ggml_is_contiguous(dst));
     WSP_GGML_ASSERT(wsp_ggml_are_same_shape(src0, dst));
-    WSP_GGML_ASSERT(wsp_ggml_is_scalar(src1));
 
     if (params->type == WSP_GGML_TASK_INIT || params->type == WSP_GGML_TASK_FINALIZE) {
         return;
     }
 
     // scale factor
-    const float v = *(float *) src1->data;
+    float v;
+    memcpy(&v, dst->op_params, sizeof(float));
 
     const int ith = params->ith;
     const int nth = params->nth;
@@ -10201,12 +10560,11 @@ static void wsp_ggml_compute_forward_scale_f32(
 static void wsp_ggml_compute_forward_scale(
         const struct wsp_ggml_compute_params * params,
         const struct wsp_ggml_tensor * src0,
-        const struct wsp_ggml_tensor * src1,
         struct wsp_ggml_tensor * dst) {
     switch (src0->type) {
         case WSP_GGML_TYPE_F32:
             {
-                wsp_ggml_compute_forward_scale_f32(params, src0, src1, dst);
+                wsp_ggml_compute_forward_scale_f32(params, src0, dst);
             } break;
         default:
             {
@@ -10310,6 +10668,8 @@ static void wsp_ggml_compute_forward_set(
         case WSP_GGML_TYPE_Q4_K:
         case WSP_GGML_TYPE_Q5_K:
         case WSP_GGML_TYPE_Q6_K:
+        case WSP_GGML_TYPE_IQ2_XXS:
+        case WSP_GGML_TYPE_IQ2_XS:
         default:
             {
                 WSP_GGML_ASSERT(false);
@@ -10504,6 +10864,8 @@ static void wsp_ggml_compute_forward_get_rows(
         case WSP_GGML_TYPE_Q4_K:
         case WSP_GGML_TYPE_Q5_K:
         case WSP_GGML_TYPE_Q6_K:
+        case WSP_GGML_TYPE_IQ2_XXS:
+        case WSP_GGML_TYPE_IQ2_XS:
             {
                 wsp_ggml_compute_forward_get_rows_q(params, src0, src1, dst);
             } break;
@@ -10512,6 +10874,7 @@ static void wsp_ggml_compute_forward_get_rows(
                 wsp_ggml_compute_forward_get_rows_f16(params, src0, src1, dst);
             } break;
         case WSP_GGML_TYPE_F32:
+        case WSP_GGML_TYPE_I32:
             {
                 wsp_ggml_compute_forward_get_rows_f32(params, src0, src1, dst);
             } break;
@@ -11139,6 +11502,8 @@ static void wsp_ggml_compute_forward_alibi(
         case WSP_GGML_TYPE_Q4_K:
         case WSP_GGML_TYPE_Q5_K:
         case WSP_GGML_TYPE_Q6_K:
+        case WSP_GGML_TYPE_IQ2_XXS:
+        case WSP_GGML_TYPE_IQ2_XS:
         case WSP_GGML_TYPE_Q8_K:
         case WSP_GGML_TYPE_I8:
         case WSP_GGML_TYPE_I16:
@@ -11213,6 +11578,8 @@ static void wsp_ggml_compute_forward_clamp(
         case WSP_GGML_TYPE_Q4_K:
         case WSP_GGML_TYPE_Q5_K:
         case WSP_GGML_TYPE_Q6_K:
+        case WSP_GGML_TYPE_IQ2_XXS:
+        case WSP_GGML_TYPE_IQ2_XS:
         case WSP_GGML_TYPE_Q8_K:
         case WSP_GGML_TYPE_I8:
         case WSP_GGML_TYPE_I16:
@@ -11257,7 +11624,22 @@ static float wsp_ggml_rope_yarn_corr_dim(int n_dims, int n_orig_ctx, float n_rot
     return n_dims * logf(n_orig_ctx / (n_rot * 2 * (float)M_PI)) / (2 * logf(base));
 }
 
-void wsp_ggml_rope_yarn_corr_dims(
+static void wsp_ggml_rope_cache_init(
+     float theta_base, float freq_scale, float corr_dims[2], int64_t ne0, float ext_factor, float mscale,
+     float * cache, float sin_sign, float theta_scale
+) {
+    float theta = theta_base;
+    for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
+        rope_yarn(
+            theta, freq_scale, corr_dims, i0, ext_factor, mscale, &cache[i0 + 0], &cache[i0 + 1]
+        );
+        cache[i0 + 1] *= sin_sign;
+
+        theta *= theta_scale;
+    }
+}
+
+WSP_GGML_CALL void wsp_ggml_rope_yarn_corr_dims(
     int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]
 ) {
     // start and end correction dims
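`wsp_ggml_rope_cache_init` hoists the per-element `rope_yarn` sin/cos evaluation out of the row loop: each thread fills a per-thread scratch region once per (i3, i2) slice, and the inner loops below just read the pair back. The layout, for reference:

    // cache[i0 + 0] = cos(theta_i)             (even offsets)
    // cache[i0 + 1] = sin(theta_i) * sin_sign  (odd offsets)
    // where theta_{i+1} = theta_i * theta_scale, starting from theta_base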
@@ -11339,6 +11721,12 @@ static void wsp_ggml_compute_forward_rope_f32(
  for (int64_t i3 = 0; i3 < ne3; i3++) {
  for (int64_t i2 = 0; i2 < ne2; i2++) {
  const int64_t p = pos[i2];
+
+ float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
+ if (!is_glm && !is_neox) { // TODO: cache sin/cos for glm, neox
+     wsp_ggml_rope_cache_init(p, freq_scale, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
+ }
+
  for (int64_t i1 = 0; i1 < ne1; i1++) {
  if (ir++ < ir0) continue;
  if (ir > ir1) break;
@@ -11372,18 +11760,13 @@ static void wsp_ggml_compute_forward_rope_f32(
  }
  } else if (!is_neox) {
  for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
- float cos_theta, sin_theta;
- rope_yarn(
-     theta_base, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta
- );
- sin_theta *= sin_sign;
+ const float cos_theta = cache[i0 + 0];
+ const float sin_theta = cache[i0 + 1];
 
  // zeta scaling for xPos only:
  float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), p / xpos_base) : 1.0f;
  if (xpos_down) zeta = 1.0f / zeta;
 
- theta_base *= theta_scale;
-
  const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
  float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
 
@@ -11395,10 +11778,13 @@ static void wsp_ggml_compute_forward_rope_f32(
  }
  } else {
  // TODO: this might be wrong for ne0 != n_dims - need double check
- // ref: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py#LL251C1-L294C28
+ // it seems we have to rope just the first n_dims elements and do nothing with the rest
+ // ref: https://github.com/ml-explore/mlx/blob/dc2edc762c797e3b8de50b1dad4dc0a131691033/benchmarks/python/llama_jax_bench.py#L11-L26
  theta_base *= freq_scale;
- for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
- for (int64_t ic = 0; ic < n_dims; ic += 2) {
+ for (int64_t ic = 0; ic < ne0; ic += 2) {
+ if (ic < n_dims) {
+ const int64_t ib = 0;
+
  // simplified from `(ib * n_dims + ic) * inv_ndims`
  float cur_rot = inv_ndims * ic - ib;
 
@@ -11421,6 +11807,14 @@ static void wsp_ggml_compute_forward_rope_f32(
 
  dst_data[0] = x0*cos_theta - x1*sin_theta;
  dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
+ } else {
+ const int64_t i0 = ic;
+
+ const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+ float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+ dst_data[0] = src[0];
+ dst_data[1] = src[1];
  }
  }
  }
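
The rewritten branch above replaces the old ne0/n_dims double loop with a single pass: positions below n_dims are rotated, and everything past them is copied through unchanged, matching the referenced MLX behavior. A standalone sketch of that split for one f32 row (hypothetical helper; cache holds interleaved cos/sin pairs):

    /* Rotate the first n_dims elements of a row (GPT-NeoX half-split pairing),
     * pass the remaining ne0 - n_dims elements through untouched. */
    static void rope_neox_row_sketch(const float * src, float * dst, int ne0,
                                     int n_dims, const float * cache) {
        for (int ic = 0; ic < ne0; ic += 2) {
            if (ic < n_dims) {
                const float cos_theta = cache[ic + 0];
                const float sin_theta = cache[ic + 1];
                const float x0 = src[ic/2];
                const float x1 = src[ic/2 + n_dims/2];
                dst[ic/2]            = x0*cos_theta - x1*sin_theta;
                dst[ic/2 + n_dims/2] = x0*sin_theta + x1*cos_theta;
            } else {
                dst[ic + 0] = src[ic + 0]; // pass-through tail
                dst[ic + 1] = src[ic + 1];
            }
        }
    }
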
@@ -11496,6 +11890,12 @@ static void wsp_ggml_compute_forward_rope_f16(
  for (int64_t i3 = 0; i3 < ne3; i3++) {
  for (int64_t i2 = 0; i2 < ne2; i2++) {
  const int64_t p = pos[i2];
+
+ float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
+ if (!is_glm && !is_neox) { // TODO: cache sin/cos for glm, neox
+     wsp_ggml_rope_cache_init(p, freq_scale, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
+ }
+
  for (int64_t i1 = 0; i1 < ne1; i1++) {
  if (ir++ < ir0) continue;
  if (ir > ir1) break;
@@ -11529,13 +11929,8 @@ static void wsp_ggml_compute_forward_rope_f16(
  }
  } else if (!is_neox) {
  for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
- float cos_theta, sin_theta;
- rope_yarn(
-     theta_base, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta
- );
- sin_theta *= sin_sign;
-
- theta_base *= theta_scale;
+ const float cos_theta = cache[i0 + 0];
+ const float sin_theta = cache[i0 + 1];
 
  const wsp_ggml_fp16_t * const src = (wsp_ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
  wsp_ggml_fp16_t * dst_data = (wsp_ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
@@ -11548,10 +11943,13 @@ static void wsp_ggml_compute_forward_rope_f16(
  }
  } else {
  // TODO: this might be wrong for ne0 != n_dims - need double check
- // ref: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py#LL251C1-L294C28
+ // it seems we have to rope just the first n_dims elements and do nothing with the rest
+ // ref: https://github.com/ml-explore/mlx/blob/dc2edc762c797e3b8de50b1dad4dc0a131691033/benchmarks/python/llama_jax_bench.py#L11-L26
  theta_base *= freq_scale;
- for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
- for (int64_t ic = 0; ic < n_dims; ic += 2) {
+ for (int64_t ic = 0; ic < ne0; ic += 2) {
+ if (ic < n_dims) {
+ const int64_t ib = 0;
+
  // simplified from `(ib * n_dims + ic) * inv_ndims`
  float cur_rot = inv_ndims * ic - ib;
 
@@ -11574,6 +11972,14 @@ static void wsp_ggml_compute_forward_rope_f16(
 
  dst_data[0] = WSP_GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
  dst_data[n_dims/2] = WSP_GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
+ } else {
+ const int64_t i0 = ic;
+
+ const wsp_ggml_fp16_t * const src = (wsp_ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+ wsp_ggml_fp16_t * dst_data = (wsp_ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+ dst_data[0] = src[0];
+ dst_data[1] = src[1];
  }
  }
  }
@@ -14182,7 +14588,7 @@ static void wsp_ggml_compute_forward(struct wsp_ggml_compute_params * params, st
  } break;
  case WSP_GGML_OP_MUL_MAT:
  {
- wsp_ggml_compute_forward_mul_mat(params, tensor->src[0], tensor->src[1], tensor, 0, tensor->ne[1]);
+ wsp_ggml_compute_forward_mul_mat(params, tensor->src[0], tensor->src[1], tensor);
  } break;
  case WSP_GGML_OP_MUL_MAT_ID:
  {
@@ -14194,7 +14600,7 @@ static void wsp_ggml_compute_forward(struct wsp_ggml_compute_params * params, st
  } break;
  case WSP_GGML_OP_SCALE:
  {
- wsp_ggml_compute_forward_scale(params, tensor->src[0], tensor->src[1], tensor);
+ wsp_ggml_compute_forward_scale(params, tensor->src[0], tensor);
  } break;
  case WSP_GGML_OP_SET:
  {
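
Both call-site changes follow from WSP_GGML_OP_SCALE taking its factor as a plain float stored in op_params instead of a src1 tensor. The value round-trips through memcpy, which sidesteps strict-aliasing and alignment concerns; a minimal sketch of the pattern:

    #include <string.h>

    /* A float scalar stored in an opaque op_params byte buffer, read back
     * the same way the SCALE backward pass does further down this diff. */
    static void set_scale_param(char * op_params, float s) {
        memcpy(op_params, &s, sizeof(float));
    }

    static float get_scale_param(const char * op_params) {
        float s;
        memcpy(&s, op_params, sizeof(float));
        return s;
    }
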
@@ -14489,7 +14895,7 @@ size_t wsp_ggml_hash_find_or_insert(struct wsp_ggml_hash_set hash_set, struct ws
  return i;
  }
 
- static struct wsp_ggml_hash_set wsp_ggml_hash_set_new(size_t size) {
+ struct wsp_ggml_hash_set wsp_ggml_hash_set_new(size_t size) {
  size = wsp_ggml_hash_size(size);
  struct wsp_ggml_hash_set result;
  result.size = size;
@@ -14558,7 +14964,7 @@ static struct wsp_ggml_tensor * wsp_ggml_recompute_graph_node(
  return replacements->vals[i];
  }
 
- struct wsp_ggml_tensor * clone = wsp_ggml_new_tensor(ctx, node->type, node->n_dims, node->ne);
+ struct wsp_ggml_tensor * clone = wsp_ggml_new_tensor(ctx, node->type, WSP_GGML_MAX_DIMS, node->ne);
 
  // insert clone into replacements
  WSP_GGML_ASSERT(replacements->set.keys[i] == NULL); // assert that we don't overwrite
@@ -14650,7 +15056,7 @@ static struct wsp_ggml_tensor * wsp_ggml_add_or_set(struct wsp_ggml_context * ct
 
  static struct wsp_ggml_tensor * wsp_ggml_acc_or_set(struct wsp_ggml_context * ctx, struct wsp_ggml_tensor * a, struct wsp_ggml_tensor * b, size_t nb1, size_t nb2, size_t nb3, size_t offset, struct wsp_ggml_hash_set zero_table) {
  if (wsp_ggml_hash_contains(zero_table, a)) {
- struct wsp_ggml_tensor * a_zero = wsp_ggml_scale(ctx, a, wsp_ggml_new_f32(ctx, 0));
+ struct wsp_ggml_tensor * a_zero = wsp_ggml_scale(ctx, a, 0.0f);
  return wsp_ggml_acc_impl(ctx, a_zero, b, nb1, nb2, nb3, offset, false);
  } else {
  return wsp_ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, false);
@@ -14786,7 +15192,7 @@ static void wsp_ggml_compute_backward(struct wsp_ggml_context * ctx, struct wsp_
  src0->grad,
  wsp_ggml_scale(ctx,
  wsp_ggml_mul(ctx, src0, tensor->grad),
- wsp_ggml_new_f32(ctx, 2.0f)),
+ 2.0f),
  zero_table);
  }
  } break;
@@ -14800,7 +15206,7 @@ static void wsp_ggml_compute_backward(struct wsp_ggml_context * ctx, struct wsp_
  wsp_ggml_div(ctx,
  tensor->grad,
  tensor),
- wsp_ggml_new_f32(ctx, 0.5f)),
+ 0.5f),
  zero_table);
  }
  } break;
@@ -14966,17 +15372,13 @@ static void wsp_ggml_compute_backward(struct wsp_ggml_context * ctx, struct wsp_
  {
  // necessary for llama
  if (src0->grad) {
+ float s;
+ memcpy(&s, tensor->op_params, sizeof(float));
+
  src0->grad =
  wsp_ggml_add_or_set(ctx,
  src0->grad,
- wsp_ggml_scale_impl(ctx, tensor->grad, src1, false),
- zero_table);
- }
- if (src1->grad) {
- src1->grad =
- wsp_ggml_add_or_set(ctx,
- src1->grad,
- wsp_ggml_sum(ctx, wsp_ggml_mul_impl(ctx, tensor->grad, src0, false)),
+ wsp_ggml_scale_impl(ctx, tensor->grad, s, false),
  zero_table);
  }
  } break;
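
With the scalar form, the SCALE backward above also simplifies: for y = s*x only dL/dx = s * dL/dy needs to propagate, and the old src1->grad branch disappears because s is no longer a tensor with a gradient of its own. The same rule on raw arrays:

    /* Backward of y = s * x: accumulate s * dy into dx; s has no gradient. */
    static void scale_backward_sketch(float * dx, const float * dy, float s, int n) {
        for (int i = 0; i < n; ++i) {
            dx[i] += s * dy[i];
        }
    }
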
@@ -15154,6 +15556,8 @@ static void wsp_ggml_compute_backward(struct wsp_ggml_context * ctx, struct wsp_
  const int n_past = ((int32_t *) tensor->op_params)[0];
  src0->grad =
  wsp_ggml_add_or_set(ctx, src0->grad,
+ /* wsp_ggml_diag_mask_inf_impl() shouldn't be here */
+ /* ref: https://github.com/ggerganov/llama.cpp/pull/4203#discussion_r1412377992 */
  wsp_ggml_diag_mask_zero_impl(ctx, tensor->grad, n_past, false),
  zero_table);
  }
@@ -15961,28 +16365,9 @@ static int wsp_ggml_get_n_tasks(struct wsp_ggml_tensor * node, int n_threads) {
 
  //n_tasks = MIN(n_threads, MAX(1, nr0/128));
  //printf("nr0 = %8d, nr1 = %8d, nr0*nr1 = %8d, n_tasks%d\n", nr0, nr1, nr0*nr1, n_tasks);
-
- #if defined(WSP_GGML_USE_CUBLAS)
- if (wsp_ggml_cuda_can_mul_mat(node->src[0], node->src[1], node)) {
-     n_tasks = 1; // TODO: this actually is doing nothing
-                  // the threads are still spinning
- }
- #elif defined(WSP_GGML_USE_CLBLAST)
- if (wsp_ggml_cl_can_mul_mat(node->src[0], node->src[1], node)) {
-     n_tasks = 1; // TODO: this actually is doing nothing
-                  // the threads are still spinning
- }
- #endif
- #if defined(WSP_GGML_USE_ACCELERATE) || defined(WSP_GGML_USE_OPENBLAS)
- if (wsp_ggml_compute_forward_mul_mat_use_blas(node->src[0], node->src[1], node)) {
-     n_tasks = 1; // TODO: this actually is doing nothing
-                  // the threads are still spinning
- }
- #endif
  } break;
  case WSP_GGML_OP_MUL_MAT_ID:
  {
- // FIXME: blas
  n_tasks = n_threads;
  } break;
  case WSP_GGML_OP_OUT_PROD:
@@ -16152,6 +16537,7 @@ static thread_ret_t wsp_ggml_graph_compute_thread(void * data) {
  state->shared->node_n += 1;
  return (thread_ret_t) WSP_GGML_EXIT_ABORTED;
  }
+
  if (atomic_fetch_sub(&state->shared->n_active, 1) == 1) {
  // all other threads are finished and spinning
  // do finalize and init here so we don't have synchronize again
@@ -16217,14 +16603,18 @@ static thread_ret_t wsp_ggml_graph_compute_thread(void * data) {
  } else {
  // wait for other threads to finish
  const int last = node_n;
+
+ const bool do_yield = last < 0 || cgraph->nodes[last]->op == WSP_GGML_OP_MUL_MAT;
+
  while (true) {
  // TODO: this sched_yield can have significant impact on the performance - either positive or negative
  // depending on the workload and the operating system.
  // since it is not clear what is the best approach, it should potentially become user-configurable
  // ref: https://github.com/ggerganov/ggml/issues/291
- #if defined(WSP_GGML_USE_ACCELERATE) || defined(WSP_GGML_USE_OPENBLAS)
- sched_yield();
- #endif
+ // UPD: adding the do_yield flag seems to resolve the issue universally
+ if (do_yield) {
+     sched_yield();
+ }
 
  node_n = atomic_load(&state->shared->node_n);
  if (node_n != last) break;
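
The BLAS-gated sched_yield thus becomes a do_yield flag: worker threads yield the CPU only while waiting on a node expected to run long (a MUL_MAT, or before the first node), and busy-spin otherwise. A reduced sketch of the wait loop using C11 atomics and POSIX sched_yield (names hypothetical):

    #include <sched.h>
    #include <stdatomic.h>
    #include <stdbool.h>

    /* Spin until the shared node index moves past `last`; yield only when
     * the node being waited on is expected to be slow. */
    static int wait_for_next_node_sketch(atomic_int * node_n, int last, bool do_yield) {
        int n;
        while ((n = atomic_load(node_n)) == last) {
            if (do_yield) {
                sched_yield();
            }
        }
        return n;
    }
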
@@ -16254,7 +16644,7 @@ static thread_ret_t wsp_ggml_graph_compute_thread(void * data) {
  return WSP_GGML_EXIT_SUCCESS;
  }
 
- struct wsp_ggml_cplan wsp_ggml_graph_plan(struct wsp_ggml_cgraph * cgraph, int n_threads) {
+ struct wsp_ggml_cplan wsp_ggml_graph_plan(const struct wsp_ggml_cgraph * cgraph, int n_threads) {
  if (n_threads <= 0) {
  n_threads = WSP_GGML_DEFAULT_N_THREADS;
  }
@@ -16303,7 +16693,7 @@ struct wsp_ggml_cplan wsp_ggml_graph_plan(struct wsp_ggml_cgraph * cgraph, int n
  } else
  #endif
  #if defined(WSP_GGML_USE_ACCELERATE) || defined(WSP_GGML_USE_OPENBLAS)
- if (wsp_ggml_compute_forward_mul_mat_use_blas(node->src[0], node->src[1], node)) {
+ if (wsp_ggml_compute_forward_mul_mat_use_blas(node)) {
  if (node->src[0]->type != WSP_GGML_TYPE_F32) {
  // here we need memory just for single 2D matrix from src0
  cur = wsp_ggml_type_size(WSP_GGML_TYPE_F32)*(node->src[0]->ne[0]*node->src[0]->ne[1]);
@@ -16311,25 +16701,22 @@ struct wsp_ggml_cplan wsp_ggml_graph_plan(struct wsp_ggml_cgraph * cgraph, int n
  } else
  #endif
  if (node->src[1]->type != vec_dot_type) {
- cur = wsp_ggml_type_size(vec_dot_type)*wsp_ggml_nelements(node->src[1])/wsp_ggml_blck_size(vec_dot_type);
+ cur = wsp_ggml_row_size(vec_dot_type, wsp_ggml_nelements(node->src[1]));
  }
  } break;
  case WSP_GGML_OP_MUL_MAT_ID:
  {
- const struct wsp_ggml_tensor * a = node->src[2];
- const struct wsp_ggml_tensor * b = node->src[1];
- const enum wsp_ggml_type vec_dot_type = type_traits[a->type].vec_dot_type;
- #if defined(WSP_GGML_USE_ACCELERATE) || defined(WSP_GGML_USE_OPENBLAS)
- if (wsp_ggml_compute_forward_mul_mat_use_blas(a, b, node)) {
-     if (a->type != WSP_GGML_TYPE_F32) {
-         // here we need memory just for single 2D matrix from src0
-         cur = wsp_ggml_type_size(WSP_GGML_TYPE_F32)*(a->ne[0]*a->ne[1]);
-     }
- } else
- #endif
- if (b->type != vec_dot_type) {
-     cur = wsp_ggml_type_size(vec_dot_type)*wsp_ggml_nelements(b)/wsp_ggml_blck_size(vec_dot_type);
+ cur = 0;
+ const struct wsp_ggml_tensor * src0 = node->src[2];
+ const struct wsp_ggml_tensor * src1 = node->src[1];
+ const enum wsp_ggml_type vec_dot_type = type_traits[src0->type].vec_dot_type;
+ if (src1->type != vec_dot_type) {
+     cur += wsp_ggml_row_size(vec_dot_type, wsp_ggml_nelements(src1));
  }
+ const int n_as = wsp_ggml_get_op_params_i32(node, 1);
+ cur += WSP_GGML_PAD(cur, sizeof(int64_t)); // align
+ cur += n_as * sizeof(int64_t); // matrix_row_counts
+ cur += n_as * src1->ne[1] * sizeof(int64_t); // matrix_rows
  } break;
  case WSP_GGML_OP_OUT_PROD:
  {
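
The MUL_MAT_ID work-size estimate now covers the expert-routing bookkeeping: an optional src1 conversion buffer, padding to int64 alignment, then n_as row counters plus an n_as x ne1 row map. The same arithmetic as a standalone sketch (pad_to stands in for WSP_GGML_PAD; note the code adds the padded total on top of cur, deliberately over-reserving, just as the diff does):

    #include <stddef.h>
    #include <stdint.h>

    static size_t pad_to(size_t x, size_t align) {
        return (x + align - 1) / align * align;
    }

    /* Work-buffer bytes for an expert-routed matmul with n_as experts and
     * ne1 src1 rows; conv_bytes is the optional src1 conversion buffer. */
    static size_t mul_mat_id_work_size_sketch(size_t conv_bytes, int n_as, int64_t ne1) {
        size_t cur = conv_bytes;
        cur += pad_to(cur, sizeof(int64_t));                   // align (over-reserves)
        cur += (size_t) n_as * sizeof(int64_t);                // matrix_row_counts
        cur += (size_t) n_as * (size_t) ne1 * sizeof(int64_t); // matrix_rows
        return cur;
    }
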
@@ -16338,6 +16725,7 @@ struct wsp_ggml_cplan wsp_ggml_graph_plan(struct wsp_ggml_cgraph * cgraph, int n
  }
  } break;
  case WSP_GGML_OP_SOFT_MAX:
+ case WSP_GGML_OP_ROPE:
  {
  cur = wsp_ggml_type_size(WSP_GGML_TYPE_F32) * node->ne[0] * n_tasks;
  } break;
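
Adding WSP_GGML_OP_ROPE here reserves the f32 work buffer that the rope kernels earlier in this diff slice per thread; each thread offsets its cache by an extra cache line of floats to reduce false sharing. The pointer arithmetic, assuming CACHE_LINE_SIZE_F32 is one cache line expressed in floats (a 64-byte line is assumed in this sketch):

    #include <stdint.h>

    /* Per-thread slice of the shared f32 work buffer, padded by one cache
     * line of floats between threads. */
    static float * thread_rope_cache_sketch(float * wdata, int64_t ne0, int ith) {
        const int64_t cache_line_f32 = 64 / (int64_t) sizeof(float);
        return wdata + (ne0 + cache_line_f32) * ith;
    }
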
@@ -16559,7 +16947,7 @@ static void wsp_ggml_graph_export_leaf(const struct wsp_ggml_tensor * tensor, FI
  fprintf(fout, "%-6s %-12s %8d %" PRId64 " %" PRId64 " %" PRId64 " %" PRId64 " %16zu %16zu %16zu %16zu %16p %32s\n",
  wsp_ggml_type_name(tensor->type),
  wsp_ggml_op_name (tensor->op),
- tensor->n_dims,
+ wsp_ggml_n_dims(tensor),
  ne[0], ne[1], ne[2], ne[3],
  nb[0], nb[1], nb[2], nb[3],
  tensor->data,
@@ -16574,7 +16962,7 @@ static void wsp_ggml_graph_export_node(const struct wsp_ggml_tensor * tensor, co
  arg,
  wsp_ggml_type_name(tensor->type),
  wsp_ggml_op_name (tensor->op),
- tensor->n_dims,
+ wsp_ggml_n_dims(tensor),
  ne[0], ne[1], ne[2], ne[3],
  nb[0], nb[1], nb[2], nb[3],
  tensor->data,
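
These exporter changes track the upstream removal of the tensor's stored n_dims field; wsp_ggml_n_dims() derives the effective rank from the shape instead. A sketch of what such a helper computes (assuming the llama.cpp-style definition: index of the last non-unit dimension plus one, floored at 1):

    #include <stdint.h>

    /* Effective rank of a shape padded to 4 dims with trailing 1s. */
    static int n_dims_sketch(const int64_t ne[4]) {
        for (int i = 3; i >= 1; --i) {
            if (ne[i] > 1) {
                return i + 1;
            }
        }
        return 1; // scalars and vectors report at least one dimension
    }
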
@@ -16664,11 +17052,9 @@ void wsp_ggml_graph_export(const struct wsp_ggml_cgraph * cgraph, const char * f
 
  const uint32_t type = tensor->type;
  const uint32_t op = tensor->op;
- const uint32_t n_dims = tensor->n_dims;
 
  fwrite(&type, sizeof(uint32_t), 1, fout);
  fwrite(&op, sizeof(uint32_t), 1, fout);
- fwrite(&n_dims, sizeof(uint32_t), 1, fout);
 
  for (int j = 0; j < WSP_GGML_MAX_DIMS; ++j) {
  const uint64_t ne = tensor->ne[j];
@@ -16698,11 +17084,9 @@ void wsp_ggml_graph_export(const struct wsp_ggml_cgraph * cgraph, const char * f
 
  const uint32_t type = tensor->type;
  const uint32_t op = tensor->op;
- const uint32_t n_dims = tensor->n_dims;
 
  fwrite(&type, sizeof(uint32_t), 1, fout);
  fwrite(&op, sizeof(uint32_t), 1, fout);
- fwrite(&n_dims, sizeof(uint32_t), 1, fout);
 
  for (int j = 0; j < WSP_GGML_MAX_DIMS; ++j) {
  const uint64_t ne = tensor->ne[j];
@@ -16874,12 +17258,10 @@ struct wsp_ggml_cgraph * wsp_ggml_graph_import(const char * fname, struct wsp_gg
  {
  uint32_t type;
  uint32_t op;
- uint32_t n_dims;
 
  for (uint32_t i = 0; i < n_leafs; ++i) {
  type = *(const uint32_t *) ptr; ptr += sizeof(type);
  op = *(const uint32_t *) ptr; ptr += sizeof(op);
- n_dims = *(const uint32_t *) ptr; ptr += sizeof(n_dims);
 
  int64_t ne[WSP_GGML_MAX_DIMS];
  size_t nb[WSP_GGML_MAX_DIMS];
@@ -16895,7 +17277,7 @@ struct wsp_ggml_cgraph * wsp_ggml_graph_import(const char * fname, struct wsp_gg
  nb[j] = nb_cur;
  }
 
- struct wsp_ggml_tensor * tensor = wsp_ggml_new_tensor(*ctx_eval, (enum wsp_ggml_type) type, n_dims, ne);
+ struct wsp_ggml_tensor * tensor = wsp_ggml_new_tensor(*ctx_eval, (enum wsp_ggml_type) type, WSP_GGML_MAX_DIMS, ne);
 
  tensor->op = (enum wsp_ggml_op) op;
 
@@ -16912,7 +17294,7 @@ struct wsp_ggml_cgraph * wsp_ggml_graph_import(const char * fname, struct wsp_gg
 
  ptr += wsp_ggml_nbytes(tensor);
 
- fprintf(stderr, "%s: loaded leaf %d: '%16s', %3d dims, %9zu bytes\n", __func__, i, tensor->name, n_dims, wsp_ggml_nbytes(tensor));
+ fprintf(stderr, "%s: loaded leaf %d: '%16s', %9zu bytes\n", __func__, i, tensor->name, wsp_ggml_nbytes(tensor));
  }
  }
 
@@ -16922,12 +17304,10 @@ struct wsp_ggml_cgraph * wsp_ggml_graph_import(const char * fname, struct wsp_gg
  {
  uint32_t type;
  uint32_t op;
- uint32_t n_dims;
 
  for (uint32_t i = 0; i < n_nodes; ++i) {
  type = *(const uint32_t *) ptr; ptr += sizeof(type);
  op = *(const uint32_t *) ptr; ptr += sizeof(op);
- n_dims = *(const uint32_t *) ptr; ptr += sizeof(n_dims);
 
  enum wsp_ggml_op eop = (enum wsp_ggml_op) op;
 
@@ -16998,7 +17378,7 @@ struct wsp_ggml_cgraph * wsp_ggml_graph_import(const char * fname, struct wsp_gg
  } break;
  default:
  {
- tensor = wsp_ggml_new_tensor(*ctx_eval, (enum wsp_ggml_type) type, n_dims, ne);
+ tensor = wsp_ggml_new_tensor(*ctx_eval, (enum wsp_ggml_type) type, WSP_GGML_MAX_DIMS, ne);
 
  tensor->op = eop;
  } break;
@@ -17017,7 +17397,7 @@ struct wsp_ggml_cgraph * wsp_ggml_graph_import(const char * fname, struct wsp_gg
 
  result->nodes[i] = tensor;
 
- fprintf(stderr, "%s: loaded node %d: '%16s', %3d dims, %9zu bytes\n", __func__, i, tensor->name, n_dims, wsp_ggml_nbytes(tensor));
+ fprintf(stderr, "%s: loaded node %d: '%16s', %9zu bytes\n", __func__, i, tensor->name, wsp_ggml_nbytes(tensor));
  }
  }
  }
@@ -17155,7 +17535,7 @@ void wsp_ggml_graph_dump_dot(const struct wsp_ggml_cgraph * gb, const struct wsp
  fprintf(fp, "(%s)|", wsp_ggml_type_name(node->type));
  }
 
- if (node->n_dims == 2) {
+ if (wsp_ggml_is_matrix(node)) {
  fprintf(fp, "%d [%" PRId64 ", %" PRId64 "] | <x>%s", i, node->ne[0], node->ne[1], wsp_ggml_op_symbol(node->op));
  } else {
  fprintf(fp, "%d [%" PRId64 ", %" PRId64 ", %" PRId64 "] | <x>%s", i, node->ne[0], node->ne[1], node->ne[2], wsp_ggml_op_symbol(node->op));
@@ -17284,9 +17664,9 @@ static void wsp_ggml_opt_acc_grad(int np, struct wsp_ggml_tensor * const ps[], f
  }
 
  //
- // ADAM
+ // Using AdamW - ref: https://arxiv.org/pdf/1711.05101v3.pdf
  //
- // ref: https://arxiv.org/pdf/1412.6980.pdf
+ // (Original Adam - ref: https://arxiv.org/pdf/1412.6980.pdf)
  //
 
  static enum wsp_ggml_opt_result wsp_ggml_opt_adam(
@@ -17422,7 +17802,7 @@ static enum wsp_ggml_opt_result wsp_ggml_opt_adam(
  int64_t i = 0;
  for (int p = 0; p < np; ++p) {
  const int64_t ne = wsp_ggml_nelements(ps[p]);
- const float p_decay = ((ps[p]->n_dims >= decay_min_ndim) ? decay : 0.0f) * sched;
+ const float p_decay = ((wsp_ggml_n_dims(ps[p]) >= decay_min_ndim) ? decay : 0.0f) * sched;
  for (int64_t j = 0; j < ne; ++j) {
  float x = wsp_ggml_get_f32_1d(ps[p], j);
  float g_ = g[i]*gnorm;
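
The corrected comment is worth spelling out: the optimizer decays the weights directly (decoupled decay, AdamW) rather than folding an L2 term into the gradient (classic Adam). One parameter update per the cited paper, as a hedged sketch:

    #include <math.h>

    /* One AdamW step for a single parameter x.
     * m, v: running first/second moments; t: 1-based step; wd: decay rate. */
    static float adamw_step_sketch(float x, float g, float * m, float * v, int t,
                                   float alpha, float beta1, float beta2,
                                   float eps, float wd) {
        *m = beta1*(*m) + (1.0f - beta1)*g;
        *v = beta2*(*v) + (1.0f - beta2)*g*g;
        const float mh = *m / (1.0f - powf(beta1, (float) t)); // bias-corrected
        const float vh = *v / (1.0f - powf(beta2, (float) t));
        // decay applied to x itself, not smuggled into the gradient:
        return x - alpha*(mh/(sqrtf(vh) + eps) + wd*x);
    }
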
@@ -18144,6 +18524,28 @@ enum wsp_ggml_opt_result wsp_ggml_opt_resume_g(
 
  ////////////////////////////////////////////////////////////////////////////////
 
+ void wsp_ggml_wsp_quantize_init(enum wsp_ggml_type type) {
+     wsp_ggml_critical_section_start();
+
+     switch (type) {
+         case WSP_GGML_TYPE_IQ2_XXS: iq2xs_init_impl(256); break;
+         case WSP_GGML_TYPE_IQ2_XS:  iq2xs_init_impl(512); break;
+         default: // nothing
+             break;
+     }
+
+     wsp_ggml_critical_section_end();
+ }
+
+ void wsp_ggml_wsp_quantize_free(void) {
+     wsp_ggml_critical_section_start();
+
+     iq2xs_free_impl(256);
+     iq2xs_free_impl(512);
+
+     wsp_ggml_critical_section_end();
+ }
+
  size_t wsp_ggml_wsp_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t * hist) {
  assert(k % QK4_0 == 0);
  const int nb = k / QK4_0;
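
The IQ2 formats need their codebook grids built before first use, which is what the new init/free pair manages inside the critical section. A hedged usage sketch (declarations assumed from this package's ggml.h; the chunk API below calls the init itself, so an explicit call only controls when the table-build cost is paid):

    static void quantize_session_sketch(void) {
        // build the 512-entry IQ2_XS grid up front (idempotent)
        wsp_ggml_wsp_quantize_init(WSP_GGML_TYPE_IQ2_XS);

        // ... quantize one or more tensors via wsp_ggml_wsp_quantize_chunk() ...

        // release both IQ2 grids once everything is written out
        wsp_ggml_wsp_quantize_free();
    }
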
@@ -18271,32 +18673,53 @@ size_t wsp_ggml_wsp_quantize_q8_0(const float * src, void * dst, int n, int k, i
  return (n/QK8_0*sizeof(block_q8_0));
  }
 
- size_t wsp_ggml_wsp_quantize_chunk(enum wsp_ggml_type type, const float * src, void * dst, int start, int n, int64_t * hist) {
+ bool wsp_ggml_wsp_quantize_requires_imatrix(enum wsp_ggml_type type) {
+     return
+         type == WSP_GGML_TYPE_IQ2_XXS ||
+         type == WSP_GGML_TYPE_IQ2_XS;
+ }
+
+ size_t wsp_ggml_wsp_quantize_chunk(enum wsp_ggml_type type, const float * src, void * dst, int start,
+         int nrows, int n_per_row, int64_t * hist, const float * imatrix) {
+     wsp_ggml_wsp_quantize_init(type); // this is noop if already initialized
  size_t result = 0;
+ int n = nrows * n_per_row;
  switch (type) {
  case WSP_GGML_TYPE_Q4_0:
  {
  WSP_GGML_ASSERT(start % QK4_0 == 0);
- block_q4_0 * block = (block_q4_0*)dst + start / QK4_0;
- result = wsp_ggml_wsp_quantize_q4_0(src + start, block, n, n, hist);
+ WSP_GGML_ASSERT(start % n_per_row == 0);
+ size_t start_row = start / n_per_row;
+ size_t row_size = wsp_ggml_row_size(type, n_per_row);
+ result = wsp_quantize_q4_0(src + start, (char *)dst + start_row * row_size, nrows, n_per_row, hist, imatrix);
+ WSP_GGML_ASSERT(result == row_size * nrows);
  } break;
  case WSP_GGML_TYPE_Q4_1:
  {
  WSP_GGML_ASSERT(start % QK4_1 == 0);
- block_q4_1 * block = (block_q4_1*)dst + start / QK4_1;
- result = wsp_ggml_wsp_quantize_q4_1(src + start, block, n, n, hist);
+ WSP_GGML_ASSERT(start % n_per_row == 0);
+ size_t start_row = start / n_per_row;
+ size_t row_size = wsp_ggml_row_size(type, n_per_row);
+ result = wsp_quantize_q4_1(src + start, (char *)dst + start_row * row_size, nrows, n_per_row, hist, imatrix);
+ WSP_GGML_ASSERT(result == row_size * nrows);
  } break;
  case WSP_GGML_TYPE_Q5_0:
  {
  WSP_GGML_ASSERT(start % QK5_0 == 0);
- block_q5_0 * block = (block_q5_0*)dst + start / QK5_0;
- result = wsp_ggml_wsp_quantize_q5_0(src + start, block, n, n, hist);
+ WSP_GGML_ASSERT(start % n_per_row == 0);
+ size_t start_row = start / n_per_row;
+ size_t row_size = wsp_ggml_row_size(type, n_per_row);
+ result = wsp_quantize_q5_0(src + start, (char *)dst + start_row * row_size, nrows, n_per_row, hist, imatrix);
+ WSP_GGML_ASSERT(result == row_size * nrows);
  } break;
  case WSP_GGML_TYPE_Q5_1:
  {
  WSP_GGML_ASSERT(start % QK5_1 == 0);
- block_q5_1 * block = (block_q5_1*)dst + start / QK5_1;
- result = wsp_ggml_wsp_quantize_q5_1(src + start, block, n, n, hist);
+ WSP_GGML_ASSERT(start % n_per_row == 0);
+ size_t start_row = start / n_per_row;
+ size_t row_size = wsp_ggml_row_size(type, n_per_row);
+ result = wsp_quantize_q5_1(src + start, (char *)dst + start_row * row_size, nrows, n_per_row, hist, imatrix);
+ WSP_GGML_ASSERT(result == row_size * nrows);
  } break;
  case WSP_GGML_TYPE_Q8_0:
  {
@@ -18307,42 +18730,77 @@ size_t wsp_ggml_wsp_quantize_chunk(enum wsp_ggml_type type, const float * src, v
  case WSP_GGML_TYPE_Q2_K:
  {
  WSP_GGML_ASSERT(start % QK_K == 0);
- block_q2_K * block = (block_q2_K*)dst + start / QK_K;
- result = wsp_ggml_wsp_quantize_q2_K(src + start, block, n, n, hist);
+ WSP_GGML_ASSERT(start % n_per_row == 0);
+ size_t start_row = start / n_per_row;
+ size_t row_size = wsp_ggml_row_size(type, n_per_row);
+ result = wsp_quantize_q2_K(src + start, (char *)dst + start_row * row_size, nrows, n_per_row, hist, imatrix);
+ WSP_GGML_ASSERT(result == row_size * nrows);
  } break;
  case WSP_GGML_TYPE_Q3_K:
  {
  WSP_GGML_ASSERT(start % QK_K == 0);
- block_q3_K * block = (block_q3_K*)dst + start / QK_K;
- result = wsp_ggml_wsp_quantize_q3_K(src + start, block, n, n, hist);
+ WSP_GGML_ASSERT(start % n_per_row == 0);
+ size_t start_row = start / n_per_row;
+ size_t row_size = wsp_ggml_row_size(type, n_per_row);
+ result = wsp_quantize_q3_K(src + start, (char *)dst + start_row * row_size, nrows, n_per_row, hist, imatrix);
+ WSP_GGML_ASSERT(result == row_size * nrows);
  } break;
  case WSP_GGML_TYPE_Q4_K:
  {
  WSP_GGML_ASSERT(start % QK_K == 0);
- block_q4_K * block = (block_q4_K*)dst + start / QK_K;
- result = wsp_ggml_wsp_quantize_q4_K(src + start, block, n, n, hist);
+ WSP_GGML_ASSERT(start % n_per_row == 0);
+ size_t start_row = start / n_per_row;
+ size_t row_size = wsp_ggml_row_size(type, n_per_row);
+ result = wsp_quantize_q4_K(src + start, (char *)dst + start_row * row_size, nrows, n_per_row, hist, imatrix);
+ WSP_GGML_ASSERT(result == row_size * nrows);
  } break;
  case WSP_GGML_TYPE_Q5_K:
  {
  WSP_GGML_ASSERT(start % QK_K == 0);
- block_q5_K * block = (block_q5_K*)dst + start / QK_K;
- result = wsp_ggml_wsp_quantize_q5_K(src + start, block, n, n, hist);
+ WSP_GGML_ASSERT(start % n_per_row == 0);
+ size_t start_row = start / n_per_row;
+ size_t row_size = wsp_ggml_row_size(type, n_per_row);
+ result = wsp_quantize_q5_K(src + start, (char *)dst + start_row * row_size, nrows, n_per_row, hist, imatrix);
+ WSP_GGML_ASSERT(result == row_size * nrows);
  } break;
  case WSP_GGML_TYPE_Q6_K:
  {
  WSP_GGML_ASSERT(start % QK_K == 0);
- block_q6_K * block = (block_q6_K*)dst + start / QK_K;
- result = wsp_ggml_wsp_quantize_q6_K(src + start, block, n, n, hist);
+ WSP_GGML_ASSERT(start % n_per_row == 0);
+ size_t start_row = start / n_per_row;
+ size_t row_size = wsp_ggml_row_size(type, n_per_row);
+ result = wsp_quantize_q6_K(src + start, (char *)dst + start_row * row_size, nrows, n_per_row, hist, imatrix);
+ WSP_GGML_ASSERT(result == row_size * nrows);
+ } break;
+ case WSP_GGML_TYPE_IQ2_XXS:
+ {
+ WSP_GGML_ASSERT(start % QK_K == 0);
+ WSP_GGML_ASSERT(start % n_per_row == 0);
+ WSP_GGML_ASSERT(imatrix);
+ size_t start_row = start / n_per_row;
+ size_t row_size = wsp_ggml_row_size(type, n_per_row);
+ result = wsp_quantize_iq2_xxs(src + start, (char *)dst + start_row * row_size, nrows, n_per_row, hist, imatrix);
+ WSP_GGML_ASSERT(result == row_size * nrows);
+ } break;
+ case WSP_GGML_TYPE_IQ2_XS:
+ {
+ WSP_GGML_ASSERT(start % QK_K == 0);
+ WSP_GGML_ASSERT(start % n_per_row == 0);
+ WSP_GGML_ASSERT(imatrix);
+ size_t start_row = start / n_per_row;
+ size_t row_size = wsp_ggml_row_size(type, n_per_row);
+ result = wsp_quantize_iq2_xs(src + start, (char *)dst + start_row * row_size, nrows, n_per_row, hist, imatrix);
+ WSP_GGML_ASSERT(result == row_size * nrows);
  } break;
  case WSP_GGML_TYPE_F16:
  {
- int elemsize = sizeof(wsp_ggml_fp16_t);
+ size_t elemsize = sizeof(wsp_ggml_fp16_t);
  wsp_ggml_fp32_to_fp16_row(src + start, (wsp_ggml_fp16_t *)dst + start, n);
  result = n * elemsize;
  } break;
  case WSP_GGML_TYPE_F32:
  {
- int elemsize = sizeof(float);
+ size_t elemsize = sizeof(float);
  result = n * elemsize;
  memcpy((uint8_t *)dst + start * elemsize, src + start, result);
  } break;
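
The chunk API is now row-oriented: callers pass nrows and n_per_row instead of one flat element count, destination offsets are computed in row-size units, and the IQ2 types additionally demand an importance matrix. A hedged caller sketch (declarations assumed from this package's ggml.h; the 16-bucket hist size mirrors the existing quantizers):

    #include <stddef.h>
    #include <stdint.h>

    /* Quantize nrows rows of width n_per_row starting at element `start`.
     * IQ2_XXS/IQ2_XS require per-column activation statistics (imatrix). */
    static size_t quantize_rows_sketch(enum wsp_ggml_type type, const float * src,
                                       void * dst, int start, int nrows,
                                       int n_per_row, const float * imatrix) {
        int64_t hist[16] = {0};
        if (wsp_ggml_wsp_quantize_requires_imatrix(type) && imatrix == NULL) {
            return 0; // would otherwise trip WSP_GGML_ASSERT(imatrix) above
        }
        return wsp_ggml_wsp_quantize_chunk(type, src, dst, start,
                                           nrows, n_per_row, hist, imatrix);
    }
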
@@ -18689,14 +19147,14 @@ struct wsp_gguf_context * wsp_gguf_init_from_file(const char * fname, struct wsp
  (int64_t) info->ne[3];
 
  if (ne % wsp_ggml_blck_size(info->type) != 0) {
- fprintf(stderr, "%s: tensor '%s' number of elements (%" PRId64 ") is not a multiple of block size (%d)\n",
-     __func__, info->name.data, ne, wsp_ggml_blck_size(info->type));
+ fprintf(stderr, "%s: tensor '%s' of type %d (%s) number of elements (%" PRId64 ") is not a multiple of block size (%d)\n",
+     __func__, info->name.data, (int)info->type, wsp_ggml_type_name(info->type), ne, wsp_ggml_blck_size(info->type));
  fclose(file);
  wsp_gguf_free(ctx);
  return NULL;
  }
 
- const size_t size_cur = (ne*wsp_ggml_type_size(info->type))/wsp_ggml_blck_size(info->type);
+ const size_t size_cur = wsp_ggml_row_size(info->type, ne);
 
  ctx->size += WSP_GGML_PAD(size_cur, ctx->alignment);
  }
@@ -18796,7 +19254,7 @@ void wsp_gguf_free(struct wsp_gguf_context * ctx) {
 
  if (ctx->kv) {
  // free string memory - not great..
- for (uint32_t i = 0; i < ctx->header.n_kv; ++i) {
+ for (uint64_t i = 0; i < ctx->header.n_kv; ++i) {
  struct wsp_gguf_kv * kv = &ctx->kv[i];
 
  if (kv->key.data) {
@@ -18812,7 +19270,7 @@ void wsp_gguf_free(struct wsp_gguf_context * ctx) {
  if (kv->type == WSP_GGUF_TYPE_ARRAY) {
  if (kv->value.arr.data) {
  if (kv->value.arr.type == WSP_GGUF_TYPE_STRING) {
- for (uint32_t j = 0; j < kv->value.arr.n; ++j) {
+ for (uint64_t j = 0; j < kv->value.arr.n; ++j) {
  struct wsp_gguf_str * str = &((struct wsp_gguf_str *) kv->value.arr.data)[j];
  if (str->data) {
  free(str->data);
@@ -18828,7 +19286,7 @@ void wsp_gguf_free(struct wsp_gguf_context * ctx) {
  }
 
  if (ctx->infos) {
- for (uint32_t i = 0; i < ctx->header.n_tensors; ++i) {
+ for (uint64_t i = 0; i < ctx->header.n_tensors; ++i) {
  struct wsp_gguf_tensor_info * info = &ctx->infos[i];
 
  if (info->name.data) {
@@ -19025,6 +19483,10 @@ char * wsp_gguf_get_tensor_name(const struct wsp_gguf_context * ctx, int i) {
  return ctx->infos[i].name.data;
  }
 
+ enum wsp_ggml_type wsp_gguf_get_tensor_type(const struct wsp_gguf_context * ctx, int i) {
+     return ctx->infos[i].type;
+ }
+
  // returns the index
  static int wsp_gguf_get_or_add_key(struct wsp_gguf_context * ctx, const char * key) {
  const int idx = wsp_gguf_find_key(ctx, key);
@@ -19175,7 +19637,7 @@ void wsp_gguf_set_kv(struct wsp_gguf_context * ctx, struct wsp_gguf_context * sr
  data[j] = ((struct wsp_gguf_str *)src->kv[i].value.arr.data)[j].data;
  }
  wsp_gguf_set_arr_str(ctx, src->kv[i].key.data, data, src->kv[i].value.arr.n);
- free(data);
+ free((void *)data);
  } else if (src->kv[i].value.arr.type == WSP_GGUF_TYPE_ARRAY) {
  WSP_GGML_ASSERT(false && "nested arrays not supported");
  } else {
@@ -19200,8 +19662,8 @@ void wsp_gguf_add_tensor(
  ctx->infos[idx].ne[i] = 1;
  }
 
- ctx->infos[idx].n_dims = tensor->n_dims;
- for (int i = 0; i < tensor->n_dims; i++) {
+ ctx->infos[idx].n_dims = wsp_ggml_n_dims(tensor);
+ for (uint32_t i = 0; i < ctx->infos[idx].n_dims; i++) {
  ctx->infos[idx].ne[i] = tensor->ne[i];
  }
 
@@ -19465,6 +19927,14 @@ int wsp_ggml_cpu_has_avx(void) {
  #endif
  }
 
+ int wsp_ggml_cpu_has_avx_vnni(void) {
+ #if defined(__AVXVNNI__)
+     return 1;
+ #else
+     return 0;
+ #endif
+ }
+
  int wsp_ggml_cpu_has_avx2(void) {
  #if defined(__AVX2__)
  return 1;
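
Like its neighbors, wsp_ggml_cpu_has_avx_vnni() answers at compile time based on the flags the binary was built with, not by probing the CPU at runtime. A small caller sketch (ggml.h declarations assumed):

    #include <stdio.h>

    static void print_cpu_features_sketch(void) {
        // each returns 1 only if the corresponding target flag was set at build time
        printf("AVX      = %d\n", wsp_ggml_cpu_has_avx());
        printf("AVX_VNNI = %d\n", wsp_ggml_cpu_has_avx_vnni());
        printf("AVX2     = %d\n", wsp_ggml_cpu_has_avx2());
    }
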