llama_cpp 0.0.1

Sign up to get free protection for your applications and to get access to all the features.
@@ -0,0 +1,773 @@
1
+ #pragma once
2
+
3
+ //
4
+ // GGML Tensor Library
5
+ //
6
+ // This documentation is still a work in progress.
7
+ // If you wish some specific topics to be covered, feel free to drop a comment:
8
+ //
9
+ // https://github.com/ggerganov/whisper.cpp/issues/40
10
+ //
11
+ // ## Overview
12
+ //
13
+ // This library implements:
14
+ //
15
+ // - a set of tensor operations
16
+ // - automatic differentiation
17
+ // - basic optimization algorithms
18
+ //
19
+ // The aim of this library is to provide a minimalistic approach for various machine learning tasks. This includes,
20
+ // but is not limited to, the following:
21
+ //
22
+ // - linear regression
23
+ // - support vector machines
24
+ // - neural networks
25
+ //
26
+ // The library allows the user to define a certain function using the available tensor operations. This function
27
+ // definition is represented internally via a computation graph. Each tensor operation in the function definition
28
+ // corresponds to a node in the graph. Having the computation graph defined, the user can choose to compute the
29
+ // function's value and/or its gradient with respect to the input variables. Optionally, the function can be optimized
30
+ // using one of the available optimization algorithms.
31
+ //
32
+ // For example, here we define the function: f(x) = a*x^2 + b
33
+ //
34
+ // {
35
+ // struct ggml_init_params params = {
36
+ // .mem_size = 16*1024*1024,
37
+ // .mem_buffer = NULL,
38
+ // };
39
+ //
40
+ // // memory allocation happens here
41
+ // struct ggml_context * ctx = ggml_init(params);
42
+ //
43
+ // struct ggml_tensor * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);
44
+ //
45
+ // ggml_set_param(ctx, x); // x is an input variable
46
+ //
47
+ // struct ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);
48
+ // struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);
49
+ // struct ggml_tensor * x2 = ggml_mul(ctx, x, x);
50
+ // struct ggml_tensor * f = ggml_add(ctx, ggml_mul(ctx, a, x2), b);
51
+ //
52
+ // ...
53
+ // }
54
+ //
55
+ // Notice that the function definition above does not involve any actual computation. The computation is performed only
56
+ // when the user explicitly requests it. For example, to compute the function's value at x = 2.0:
57
+ //
58
+ // {
59
+ // ...
60
+ //
61
+ // struct ggml_cgraph gf = ggml_build_forward(f);
62
+ //
63
+ // // set the input variable and parameter values
64
+ // ggml_set_f32(x, 2.0f);
65
+ // ggml_set_f32(a, 3.0f);
66
+ // ggml_set_f32(b, 4.0f);
67
+ //
68
+ // ggml_graph_compute(ctx, &gf);
69
+ //
70
+ // printf("f = %f\n", ggml_get_f32_1d(f, 0));
71
+ //
72
+ // ...
73
+ // }
74
+ //
75
+ // The actual computation is performed in the ggml_graph_compute() function.
76
+ //
77
+ // The ggml_new_tensor_...() functions create new tensors. They are allocated in the memory buffer provided to the
78
+ // ggml_init() function. You have to be careful not to exceed the memory buffer size. Therefore, you have to know
79
+ // in advance how much memory you need for your computation. Alternatively, you can allocate a large enough memory
80
+ // and after defining the computation graph, call the ggml_used_mem() function to find out how much memory was
81
+ // actually needed.
82
+ //
83
+ // The ggml_set_param() function marks a tensor as an input variable. This is used by the automatic
84
+ // differentiation and optimization algorithms.
85
+ //
86
+ // The described approach allows the user to define the function graph once and then compute its forward or backward graphs
87
+ // multiple times. All computations will use the same memory buffer allocated in the ggml_init() function. This way
88
+ // the user can avoid the memory allocation overhead at runtime.
89
+ //
90
+ // The library supports multi-dimensional tensors - up to 4 dimensions. The FP16 and FP32 data types are first class
91
+ // citizens. Integer (I8, I16, I32) and 4-bit quantized (Q4_0, Q4_1) types are also declared in enum ggml_type.
92
+ //
93
+ // Each tensor operation produces a new tensor. Initially the library was envisioned to support only the use of unary
94
+ // and binary operations. Most of the available operations fall into one of these two categories. With time, it became
95
+ // clear that the library needs to support more complex operations. The way to support these operations is not clear
96
+ // yet, but a few examples are demonstrated in the following operations:
97
+ //
98
+ // - ggml_permute()
99
+ // - ggml_conv_1d_1s()
100
+ // - ggml_conv_1d_2s()
101
+ //
102
+ // For each tensor operator, the library implements a forward and backward computation function. The forward function
103
+ // computes the output tensor value given the input tensor values. The backward function computes the adjoint of the
104
+ // input tensors given the adjoint of the output tensor. For a detailed explanation of what this means, take a
105
+ // calculus class, or watch the following video:
106
+ //
107
+ // What is Automatic Differentiation?
108
+ // https://www.youtube.com/watch?v=wG_nF1awSSY
109
+ //
110
+ //
111
+ // ## Tensor data (struct ggml_tensor)
112
+ //
113
+ // The tensors are stored in memory via the ggml_tensor struct. The structure provides information about the size of
114
+ // the tensor, the data type, and the memory buffer where the tensor data is stored. Additionally, it contains
115
+ // pointers to the "source" tensors - i.e. the tensors that were used to compute the current tensor. For example:
116
+ //
117
+ // {
118
+ // struct ggml_tensor * c = ggml_add(ctx, a, b);
119
+ //
120
+ // assert(c->src0 == a);
121
+ // assert(c->src1 == b);
122
+ // }
123
+ //
124
+ // The multi-dimensional tensors are stored in row-major order. The ggml_tensor struct contains fields for the
125
+ // number of elements in each dimension ("ne") as well as the number of bytes ("nb", a.k.a. stride). This allows
126
+ // to store tensors that are not contiguous in memory, which is useful for operations such as transposition and
127
+ // permutation. All tensor operations have to take the stride into account and not assume that the tensor is
128
+ // contiguous in memory.
129
+ //
130
+ // The data of the tensor is accessed via the "data" pointer. For example:
131
+ //
132
+ // {
133
+ // struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 2, 3);
134
+ //
135
+ // // a[1, 2] = 1.0f;   (indices are written a[i0, i1], with i0 < ne[0] = 2 and i1 < ne[1] = 3)
136
+ // *(float *) ((char *) a->data + 2*a->nb[1] + 1*a->nb[0]) = 1.0f;
137
+ //
138
+ // // a[1, 0] = 2.0f;
139
+ // *(float *) ((char *) a->data + 0*a->nb[1] + 1*a->nb[0]) = 2.0f;
140
+ //
141
+ // ...
142
+ // }
143
+ //
144
+ // Alternatively, there are helper functions, such as ggml_get_f32_1d() and ggml_set_f32_1d() that can be used.
145
+ //
146
+ // ## The matrix multiplication operator (ggml_mul_mat)
147
+ //
148
+ // TODO
149
+ //
150
+ //
151
+ // ## Multi-threading
152
+ //
153
+ // TODO
154
+ //
155
+ //
156
+ // ## Overview of ggml.c
157
+ //
158
+ // TODO
159
+ //
160
+ //
161
+ // ## SIMD optimizations
162
+ //
163
+ // TODO
164
+ //
165
+ //
166
+ // ## Debugging ggml
167
+ //
168
+ // TODO
169
+ //
170
+ //
171
+
// The standard headers are included before the extern "C" block. Only this
// library's own declarations need C linkage. Wrapping system headers in a
// linkage specification can break C++ builds whose C headers pull in C++ code.
#include <stdint.h>
#include <stddef.h>
#include <stdbool.h>

#ifdef __cplusplus
extern "C" {
#endif

// Compile-time capacity limits.
#define GGML_MAX_DIMS     4    // length of the ne[]/nb[] arrays in struct ggml_tensor
#define GGML_MAX_NODES    4096 // capacity of nodes[]/grads[]/leafs[] in struct ggml_cgraph
#define GGML_MAX_PARAMS   16   // not referenced in this header
#define GGML_MAX_CONTEXTS 64   // not referenced in this header
#define GGML_MAX_OPT      4    // length of the opt[] array in struct ggml_tensor

// Storage type for FP16 values.
#ifdef __ARM_NEON
// we use the built-in 16-bit float type
typedef __fp16 ggml_fp16_t;
#else
// no native half type available: the value is kept in a 16-bit unsigned integer
typedef uint16_t ggml_fp16_t;
#endif

// convert FP16 <-> FP32
float       ggml_fp16_to_fp32(ggml_fp16_t x);
ggml_fp16_t ggml_fp32_to_fp16(float x);

// Opaque types: only forward-declared in this header.
struct ggml_object;
struct ggml_context;

// Element types a tensor can hold. The explicit values equal the implicit
// 0..N numbering, written out so they can be read at a glance.
enum ggml_type {
    // quantized formats (see ggml_quantize_q4_0 / ggml_quantize_q4_1)
    GGML_TYPE_Q4_0  = 0,
    GGML_TYPE_Q4_1  = 1,

    // integer formats
    GGML_TYPE_I8    = 2,
    GGML_TYPE_I16   = 3,
    GGML_TYPE_I32   = 4,

    // floating-point formats
    GGML_TYPE_F16   = 5,
    GGML_TYPE_F32   = 6,

    // number of entries above
    GGML_TYPE_COUNT = 7,
};

// Tensor operations that can appear as graph nodes (struct ggml_tensor.op).
// The enumerators are numbered implicitly in declaration order starting at 0.
enum ggml_op {
    GGML_OP_NONE = 0,

    // element-wise and reduction ops
    GGML_OP_DUP,
    GGML_OP_ADD,
    GGML_OP_SUB,
    GGML_OP_MUL,
    GGML_OP_DIV,
    GGML_OP_SQR,
    GGML_OP_SQRT,
    GGML_OP_SUM,
    GGML_OP_MEAN,
    GGML_OP_REPEAT,
    GGML_OP_ABS,
    GGML_OP_SGN,
    GGML_OP_NEG,
    GGML_OP_STEP,
    GGML_OP_RELU,
    GGML_OP_GELU,
    GGML_OP_SILU,
    GGML_OP_NORM,     // normalize
    GGML_OP_RMS_NORM,

    // matrix multiplication
    GGML_OP_MUL_MAT,

    // shape / layout / misc ops
    GGML_OP_SCALE,
    GGML_OP_CPY,
    GGML_OP_RESHAPE,
    GGML_OP_VIEW,
    GGML_OP_PERMUTE,
    GGML_OP_TRANSPOSE,
    GGML_OP_GET_ROWS,
    GGML_OP_DIAG_MASK_INF,
    GGML_OP_SOFT_MAX,
    GGML_OP_ROPE,
    GGML_OP_CONV_1D_1S,
    GGML_OP_CONV_1D_2S,

    // fused ops
    GGML_OP_FLASH_ATTN,
    GGML_OP_FLASH_FF,

    // number of entries above
    GGML_OP_COUNT,
};

256
+ // n-dimensional tensor
257
+ struct ggml_tensor {
258
+ enum ggml_type type;
259
+
260
+ int n_dims;
261
+ int ne[GGML_MAX_DIMS]; // number of elements
262
+ size_t nb[GGML_MAX_DIMS]; // stride in bytes:
263
+ // nb[0] = sizeof(type)
264
+ // nb[1] = nb[0] * ne[0] + padding
265
+ // nb[i] = nb[i-1] * ne[i-1]
266
+
267
+ // compute data
268
+ enum ggml_op op;
269
+
270
+ bool is_param;
271
+
272
+ struct ggml_tensor * grad;
273
+ struct ggml_tensor * src0;
274
+ struct ggml_tensor * src1;
275
+ struct ggml_tensor * opt[GGML_MAX_OPT];
276
+
277
+ // thread scheduling
278
+ int n_tasks;
279
+
280
+ // performance
281
+ int perf_runs;
282
+ int64_t perf_cycles;
283
+ int64_t perf_time_us;
284
+
285
+ void * data;
286
+ char padding[8];
287
+ };
288
+
289
+ // computation graph
290
+ struct ggml_cgraph {
291
+ int n_nodes;
292
+ int n_leafs;
293
+ int n_threads;
294
+
295
+ size_t work_size;
296
+ struct ggml_tensor * work;
297
+
298
+ struct ggml_tensor * nodes[GGML_MAX_NODES];
299
+ struct ggml_tensor * grads[GGML_MAX_NODES];
300
+ struct ggml_tensor * leafs[GGML_MAX_NODES];
301
+
302
+ // performance
303
+ int perf_runs;
304
+ int64_t perf_cycles;
305
+ int64_t perf_time_us;
306
+ };
307
+
// Scratch buffer descriptor, passed by value to ggml_set_scratch().
struct ggml_scratch {
    size_t offs;   // presumably the current offset into data, in bytes
    size_t size;   // presumably the capacity of data, in bytes
    void * data;
};

// Parameters for ggml_init(): the memory pool backing a context.
struct ggml_init_params {
    size_t mem_size;   // pool size in bytes
    void * mem_buffer; // caller-provided pool; if NULL, memory will be allocated internally
};

//
// timing: call ggml_time_init() once at the beginning of the program
//
void    ggml_time_init(void);
int64_t ggml_time_ms(void);
int64_t ggml_time_us(void);
int64_t ggml_cycles(void);
int64_t ggml_cycles_per_ms(void);

// debug printing
void ggml_print_object(const struct ggml_object * obj);
void ggml_print_objects(const struct ggml_context * ctx);

// tensor size queries
int    ggml_nelements(const struct ggml_tensor * tensor);
size_t ggml_nbytes(const struct ggml_tensor * tensor);

// per-type size queries
int    ggml_blck_size(enum ggml_type type);
size_t ggml_type_size(enum ggml_type type);  // size in bytes for all elements in a block
float  ggml_type_sizef(enum ggml_type type); // ggml_type_size()/ggml_blck_size() as float

size_t ggml_element_size(const struct ggml_tensor * tensor);

// context lifetime: ggml_init() is where the memory pool gets allocated
struct ggml_context * ggml_init(struct ggml_init_params params);
void                  ggml_free(struct ggml_context * ctx);

// memory used so far; handy for finding out how large mem_size actually needs to be
size_t ggml_used_mem(const struct ggml_context * ctx);

size_t ggml_set_scratch(struct ggml_context * ctx, struct ggml_scratch scratch);

// memory locking (NOTE(review): semantics of *err_p on failure not documented here)
bool ggml_mlock_supported(void);
bool ggml_mlock(struct ggml_context * ctx, char ** err_p);

//
// tensor creation: tensors are allocated from the memory pool given to ggml_init()
//
struct ggml_tensor * ggml_new_tensor(
        struct ggml_context * ctx, enum ggml_type type, int n_dims, const int *ne);

struct ggml_tensor * ggml_new_tensor_1d(
        struct ggml_context * ctx, enum ggml_type type, int ne0);

struct ggml_tensor * ggml_new_tensor_2d(
        struct ggml_context * ctx, enum ggml_type type, int ne0, int ne1);

struct ggml_tensor * ggml_new_tensor_3d(
        struct ggml_context * ctx, enum ggml_type type, int ne0, int ne1, int ne2);

struct ggml_tensor * ggml_new_tensor_4d(
        struct ggml_context * ctx, enum ggml_type type, int ne0, int ne1, int ne2, int ne3);

// scalar constructors
struct ggml_tensor * ggml_new_i32(struct ggml_context * ctx, int32_t value);
struct ggml_tensor * ggml_new_f32(struct ggml_context * ctx, float value);

// new tensors derived from src
struct ggml_tensor * ggml_dup_tensor(struct ggml_context * ctx, const struct ggml_tensor * src);
struct ggml_tensor * ggml_view_tensor(struct ggml_context * ctx, const struct ggml_tensor * src);

// fill the whole tensor
struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor);
struct ggml_tensor * ggml_set_i32(struct ggml_tensor * tensor, int32_t value);
struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value);

// element access by index i
int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i);
void    ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value);

float   ggml_get_f32_1d(const struct ggml_tensor * tensor, int i);
void    ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value);

// raw data pointer
void  * ggml_get_data(const struct ggml_tensor * tensor);
float * ggml_get_data_f32(const struct ggml_tensor * tensor);

//
// operations on tensors with backpropagation
//
// These calls only add a node to the graph; the computation itself happens
// later, in ggml_graph_compute().
//

struct ggml_tensor * ggml_dup(struct ggml_context * ctx, struct ggml_tensor * a);

// binary arithmetic
struct ggml_tensor * ggml_add(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b);
struct ggml_tensor * ggml_sub(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b);
struct ggml_tensor * ggml_mul(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b);
struct ggml_tensor * ggml_div(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b);

struct ggml_tensor * ggml_sqr (struct ggml_context * ctx, struct ggml_tensor * a);
struct ggml_tensor * ggml_sqrt(struct ggml_context * ctx, struct ggml_tensor * a);

// return scalar
// TODO: compute sum along rows
struct ggml_tensor * ggml_sum(struct ggml_context * ctx, struct ggml_tensor * a);

// mean along rows
struct ggml_tensor * ggml_mean(struct ggml_context * ctx, struct ggml_tensor * a);

// if a is the same shape as b, and a is not parameter, return a
// otherwise, return a new tensor: repeat(a) to fit in b
struct ggml_tensor * ggml_repeat(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b);

// unary ops
struct ggml_tensor * ggml_abs (struct ggml_context * ctx, struct ggml_tensor * a);
struct ggml_tensor * ggml_sgn (struct ggml_context * ctx, struct ggml_tensor * a);
struct ggml_tensor * ggml_neg (struct ggml_context * ctx, struct ggml_tensor * a);
struct ggml_tensor * ggml_step(struct ggml_context * ctx, struct ggml_tensor * a);
struct ggml_tensor * ggml_relu(struct ggml_context * ctx, struct ggml_tensor * a);

// TODO: double-check this computation is correct
struct ggml_tensor * ggml_gelu(struct ggml_context * ctx, struct ggml_tensor * a);

struct ggml_tensor * ggml_silu(struct ggml_context * ctx, struct ggml_tensor * a);

// normalize along rows
// TODO: eps is hardcoded to 1e-5 for now
struct ggml_tensor * ggml_norm(struct ggml_context * ctx, struct ggml_tensor * a);

struct ggml_tensor * ggml_rms_norm(struct ggml_context * ctx, struct ggml_tensor * a);

// A: m rows, n columns
// B: p rows, n columns (i.e. we transpose it internally)
// result is m columns, p rows
struct ggml_tensor * ggml_mul_mat(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b);

//
// operations on tensors without backpropagation
//

// in-place, returns view(a)
struct ggml_tensor * ggml_scale(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b);

// a -> b, return view(b)
struct ggml_tensor * ggml_cpy(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b);

// reshapes: each returns view(a)
// TODO: when we start computing gradient, make a copy instead of view
struct ggml_tensor * ggml_reshape(struct ggml_context * ctx, struct ggml_tensor * a,
                                  struct ggml_tensor * b); // b specifies the new shape
struct ggml_tensor * ggml_reshape_2d(struct ggml_context * ctx, struct ggml_tensor * a,
                                     int ne0, int ne1);
struct ggml_tensor * ggml_reshape_3d(struct ggml_context * ctx, struct ggml_tensor * a,
                                     int ne0, int ne1, int ne2);

// views into a
struct ggml_tensor * ggml_view_1d(struct ggml_context * ctx, struct ggml_tensor * a,
                                  int ne0,
                                  size_t offset); // offset in bytes
struct ggml_tensor * ggml_view_2d(struct ggml_context * ctx, struct ggml_tensor * a,
                                  int ne0, int ne1,
                                  size_t nb1,     // row stride in bytes
                                  size_t offset);

struct ggml_tensor * ggml_permute(struct ggml_context * ctx, struct ggml_tensor * a,
                                  int axis0, int axis1, int axis2, int axis3);

// alias for ggml_permute(ctx, a, 1, 0, 2, 3)
struct ggml_tensor * ggml_transpose(struct ggml_context * ctx, struct ggml_tensor * a);

struct ggml_tensor * ggml_get_rows(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b);

// set elements above the diagonal to -INF
// in-place, returns view(a)
struct ggml_tensor * ggml_diag_mask_inf(struct ggml_context * ctx, struct ggml_tensor * a, int n_past);

// in-place, returns view(a)
struct ggml_tensor * ggml_soft_max(struct ggml_context * ctx, struct ggml_tensor * a);

// rotary position embedding
// in-place, returns view(a)
// if mode == 1, skip n_past elements
// TODO: avoid creating a new tensor every time
struct ggml_tensor * ggml_rope(struct ggml_context * ctx, struct ggml_tensor * a,
                               int n_past, int n_dims, int mode);

// padding = 1
// TODO: we don't support extra parameters for now
//       that's why we are hard-coding the stride, padding, and dilation
//       not great ..
struct ggml_tensor * ggml_conv_1d_1s(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b);
struct ggml_tensor * ggml_conv_1d_2s(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b);

struct ggml_tensor * ggml_flash_attn(struct ggml_context * ctx,
                                     struct ggml_tensor * q,
                                     struct ggml_tensor * k,
                                     struct ggml_tensor * v,
                                     bool masked);

struct ggml_tensor * ggml_flash_ff(struct ggml_context * ctx,
                                   struct ggml_tensor * a,
                                   struct ggml_tensor * b0,
                                   struct ggml_tensor * b1,
                                   struct ggml_tensor * c0,
                                   struct ggml_tensor * c1);

//
// automatic differentiation and graph execution
//

// mark tensor as an input variable (used by autodiff and the optimizers)
void ggml_set_param(struct ggml_context * ctx, struct ggml_tensor * tensor);

void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor);

// graph construction (graphs are returned by value)
struct ggml_cgraph ggml_build_forward(struct ggml_tensor * tensor);
struct ggml_cgraph ggml_build_backward(struct ggml_context * ctx, struct ggml_cgraph * gf, bool keep);

// the actual computation is performed by ggml_graph_compute()
void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph);
void ggml_graph_reset(struct ggml_cgraph * cgraph);

// print info and performance information for the graph
void ggml_graph_print(const struct ggml_cgraph * cgraph);

// dump the graph into a file using the dot format
void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph * gf, const char * filename);

//
// optimization
//

// optimization methods
enum ggml_opt_type {
    GGML_OPT_ADAM,
    GGML_OPT_LBFGS,
};

// line search methods, selected through ggml_opt_params.lbfgs.linesearch
enum ggml_linesearch {
    GGML_LINESEARCH_BACKTRACKING_ARMIJO       = 0,
    GGML_LINESEARCH_BACKTRACKING_WOLFE        = 1,
    GGML_LINESEARCH_BACKTRACKING_STRONG_WOLFE = 2,

    // Default method. It is declared as an alias rather than a repeated literal,
    // so the default cannot silently drift away from the method it names.
    GGML_LINESEARCH_DEFAULT = GGML_LINESEARCH_BACKTRACKING_WOLFE,
};

// Return codes of ggml_opt(). Optimizer outcomes start at 0; the
// line-search codes occupy a separate range starting at -128.
enum ggml_opt_result {
    GGML_OPT_OK                        = 0,
    GGML_OPT_DID_NOT_CONVERGE          = 1,
    GGML_OPT_NO_CONTEXT                = 2,
    GGML_OPT_INVALID_WOLFE             = 3,
    GGML_OPT_FAIL                      = 4,

    GGML_LINESEARCH_FAIL               = -128,
    GGML_LINESEARCH_MINIMUM_STEP       = -127,
    GGML_LINESEARCH_MAXIMUM_STEP       = -126,
    GGML_LINESEARCH_MAXIMUM_ITERATIONS = -125,
    GGML_LINESEARCH_INVALID_PARAMETERS = -124,
};

682
+ // optimization parameters
683
+ //
684
+ // see ggml.c (ggml_opt_default_params) for default values
685
+ //
686
+ struct ggml_opt_params {
687
+ enum ggml_opt_type type;
688
+
689
+ int n_threads;
690
+
691
+ // delta-based convergence test
692
+ //
693
+ // if past == 0 - disabled
694
+ // if past > 0:
695
+ // stop if |f(x) - f(x_past)| < delta * max(1, |f(x)|)
696
+ //
697
+ int past;
698
+ float delta;
699
+
700
+ // maximum number of iterations without improvement
701
+ //
702
+ // if 0 - disabled
703
+ // if > 0:
704
+ // assume convergence if no cost improvement in this number of iterations
705
+ //
706
+ int max_no_improvement;
707
+
708
+ bool print_forward_graph;
709
+ bool print_backward_graph;
710
+
711
+ // ADAM parameters
712
+ struct {
713
+ int n_iter;
714
+
715
+ float alpha; // learning rate
716
+ float beta1;
717
+ float beta2;
718
+ float eps; // epsilon for numerical stability
719
+ float eps_f; // epsilon for convergence test
720
+ float eps_g; // epsilon for convergence test
721
+ } adam;
722
+
723
+ // LBFGS parameters
724
+ struct {
725
+ int m; // number of corrections to approximate the inv. Hessian
726
+ int n_iter;
727
+ int max_linesearch;
728
+
729
+ float eps; // convergence tolerance
730
+ float ftol; // line search tolerance
731
+ float wolfe;
732
+ float min_step;
733
+ float max_step;
734
+
735
+ enum ggml_linesearch linesearch;
736
+ } lbfgs;
737
+ };
738
+
// Default optimizer settings for the given method.
struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type);

// Optimize the function defined by the tensor f; the outcome is one of the
// enum ggml_opt_result codes.
enum ggml_opt_result ggml_opt(
        struct ggml_context * ctx, struct ggml_opt_params params, struct ggml_tensor * f);

//
// quantization
//
// Convert float data from src into the Q4_0 / Q4_1 formats in dst.
// NOTE(review): the meaning of n, k, hist and of the size_t result is not
// documented in this header; confirm against ggml.c before relying on it.
//
size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t * hist);
size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t * hist);

//
// system info: each query reports whether the named feature is available
// (presumably non-zero for yes, 0 for no)
//

// x86 features
int ggml_cpu_has_avx(void);
int ggml_cpu_has_avx2(void);
int ggml_cpu_has_avx512(void);
int ggml_cpu_has_fma(void);

// ARM features
int ggml_cpu_has_neon(void);
int ggml_cpu_has_arm_fma(void);

int ggml_cpu_has_f16c(void);
int ggml_cpu_has_fp16_va(void);
int ggml_cpu_has_wasm_simd(void);
int ggml_cpu_has_blas(void);
int ggml_cpu_has_sse3(void);
int ggml_cpu_has_vsx(void);

#if defined(__cplusplus)
} // extern "C"
#endif