cui-llama.rn 1.1.4 → 1.1.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/cpp/common.h CHANGED
@@ -4,20 +4,9 @@
 
 #include "llama.h"
 
-#include "sampling.h"
-
-#define LOG_NO_FILE_LINE_FUNCTION
-#include "log.h"
-
-#include <cmath>
 #include <string>
 #include <vector>
-#include <random>
-#include <thread>
-#include <set>
-#include <unordered_map>
-#include <tuple>
-#include <functional>
+#include <sstream>
 
 #ifdef _WIN32
 #define DIRECTORY_SEPARATOR '\\'
@@ -67,11 +56,20 @@ extern char const *LLAMA_BUILD_TARGET;
 // CPU utils
 //
 
+struct cpu_params {
+    int n_threads = -1;
+    bool cpumask[LM_GGML_MAX_N_THREADS] = {false}; // CPU affinity mask.
+    bool mask_valid = false; // Default: any CPU
+    enum lm_ggml_sched_priority priority = LM_GGML_SCHED_PRIO_NORMAL; // Scheduling prio : (0 - normal, 1 - medium, 2 - high, 3 - realtime)
+    bool strict_cpu = false; // Use strict CPU placement
+    uint32_t poll = 50; // Polling (busywait) level (0 - no polling, 100 - mostly polling)
+};
+
 int32_t cpu_get_num_physical_cores();
 int32_t cpu_get_num_math();
 
 //
-// CLI argument parsing
+// Common params
 //
 
 enum llama_example {
@@ -89,27 +87,76 @@ enum llama_example {
     LLAMA_EXAMPLE_CVECTOR_GENERATOR,
     LLAMA_EXAMPLE_EXPORT_LORA,
     LLAMA_EXAMPLE_LLAVA,
+    LLAMA_EXAMPLE_LOOKUP,
+    LLAMA_EXAMPLE_PARALLEL,
 
     LLAMA_EXAMPLE_COUNT,
 };
 
+enum gpt_sampler_type {
+    GPT_SAMPLER_TYPE_NONE = 0,
+    GPT_SAMPLER_TYPE_TOP_K = 1,
+    GPT_SAMPLER_TYPE_TOP_P = 2,
+    GPT_SAMPLER_TYPE_MIN_P = 3,
+    GPT_SAMPLER_TYPE_TFS_Z = 4,
+    GPT_SAMPLER_TYPE_TYPICAL_P = 5,
+    GPT_SAMPLER_TYPE_TEMPERATURE = 6,
+    GPT_SAMPLER_TYPE_XTC = 7,
+};
+
 // dimensionality reduction methods, used by cvector-generator
 enum dimre_method {
     DIMRE_METHOD_PCA,
     DIMRE_METHOD_MEAN,
 };
 
-struct cpu_params {
-    int n_threads = -1;
-    bool cpumask[LM_GGML_MAX_N_THREADS] = {false}; // CPU affinity mask.
-    bool mask_valid = false; // Default: any CPU
-    enum lm_ggml_sched_priority priority = LM_GGML_SCHED_PRIO_NORMAL; // Scheduling prio : (0 - normal, 1 - medium, 2 - high, 3 - realtime)
-    bool strict_cpu = false; // Use strict CPU placement
-    uint32_t poll = 50; // Polling (busywait) level (0 - no polling, 100 - mostly polling)
+// sampler parameters
+struct gpt_sampler_params {
+    uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampler
+
+    int32_t n_prev = 64; // number of previous tokens to remember
+    int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
+    int32_t min_keep = 0; // 0 = disabled, otherwise samplers should return at least min_keep tokens
+    int32_t top_k = 40; // <= 0 to use vocab size
+    float top_p = 0.95f; // 1.0 = disabled
+    float min_p = 0.05f; // 0.0 = disabled
+    float tfs_z = 1.00f; // 1.0 = disabled
+    float xtc_t = 0.0f; // 0.0 = disabled
+    float xtc_p = 0.0f;
+    float typ_p = 1.00f; // typical_p, 1.0 = disabled
+    float temp = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities
+    float dynatemp_range = 0.00f; // 0.0 = disabled
+    float dynatemp_exponent = 1.00f; // controls how entropy maps to temperature in dynamic temperature sampler
+    int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size)
+    float penalty_repeat = 1.00f; // 1.0 = disabled
+    float penalty_freq = 0.00f; // 0.0 = disabled
+    float penalty_present = 0.00f; // 0.0 = disabled
+    int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
+    float mirostat_tau = 5.00f; // target entropy
+    float mirostat_eta = 0.10f; // learning rate
+    bool penalize_nl = false; // consider newlines as a repeatable token
+    bool ignore_eos = false;
+    bool no_perf = false; // disable performance metrics
+
+    std::vector<enum gpt_sampler_type> samplers = {
+        GPT_SAMPLER_TYPE_TOP_K,
+        GPT_SAMPLER_TYPE_TFS_Z,
+        GPT_SAMPLER_TYPE_TYPICAL_P,
+        GPT_SAMPLER_TYPE_TOP_P,
+        GPT_SAMPLER_TYPE_MIN_P,
+        GPT_SAMPLER_TYPE_TEMPERATURE,
+        GPT_SAMPLER_TYPE_XTC
+    };
+
+    std::string grammar; // optional BNF-like grammar to constrain sampling
+
+    std::vector<llama_logit_bias> logit_bias; // logit biases to apply
+
+    // print the parameters into a string
+    std::string print() const;
 };
 
 struct gpt_params {
-    enum llama_example curr_ex = LLAMA_EXAMPLE_COMMON;
 
     bool vocab_only = false;
     int32_t n_predict = -1; // new tokens to predict
@@ -155,23 +202,23 @@ struct gpt_params {
 
     struct gpt_sampler_params sparams;
 
-    std::string model = ""; // model path
-    std::string model_draft = ""; // draft model for speculative decoding
-    std::string model_alias = "unknown"; // model alias
-    std::string model_url = ""; // model url to download
-    std::string hf_token = ""; // HF token
-    std::string hf_repo = ""; // HF repo
-    std::string hf_file = ""; // HF file
-    std::string prompt = "";
-    std::string prompt_file = ""; // store the external prompt file name
-    std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state
-    std::string input_prefix = ""; // string to prefix user inputs with
-    std::string input_suffix = ""; // string to suffix user inputs with
-    std::string logdir = ""; // directory in which to save YAML log files
-    std::string lookup_cache_static = ""; // path of static ngram cache file for lookup decoding
-    std::string lookup_cache_dynamic = ""; // path of dynamic ngram cache file for lookup decoding
-    std::string logits_file = ""; // file for saving *all* logits
-    std::string rpc_servers = ""; // comma separated list of RPC servers
+    std::string model = ""; // model path // NOLINT
+    std::string model_draft = ""; // draft model for speculative decoding // NOLINT
+    std::string model_alias = "unknown"; // model alias // NOLINT
+    std::string model_url = ""; // model url to download // NOLINT
+    std::string hf_token = ""; // HF token // NOLINT
+    std::string hf_repo = ""; // HF repo // NOLINT
+    std::string hf_file = ""; // HF file // NOLINT
+    std::string prompt = ""; // NOLINT
+    std::string prompt_file = ""; // store the external prompt file name // NOLINT
+    std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state // NOLINT
+    std::string input_prefix = ""; // string to prefix user inputs with // NOLINT
+    std::string input_suffix = ""; // string to suffix user inputs with // NOLINT
+    std::string logdir = ""; // directory in which to save YAML log files // NOLINT
+    std::string lookup_cache_static = ""; // path of static ngram cache file for lookup decoding // NOLINT
+    std::string lookup_cache_dynamic = ""; // path of dynamic ngram cache file for lookup decoding // NOLINT
+    std::string logits_file = ""; // file for saving *all* logits // NOLINT
+    std::string rpc_servers = ""; // comma separated list of RPC servers // NOLINT
 
     std::vector<std::string> in_files; // all input files
     std::vector<std::string> antiprompt; // strings upon which more user input is prompted (a.k.a. reverse prompts)
@@ -201,7 +248,6 @@ struct gpt_params {
 
     bool kl_divergence = false; // compute KL divergence
 
-    std::function<void(int, char **)> print_usage = nullptr; // print example-specific usage and example
     bool usage = false; // print usage
     bool use_color = false; // use color to distinguish generations and inputs
     bool special = false; // enable special token output
@@ -216,6 +262,8 @@ struct gpt_params {
     bool simple_io = false; // improves compatibility with subprocesses and limited consoles
     bool cont_batching = true; // insert new sequences for decoding on-the-fly
     bool flash_attn = false; // flash attention
+    bool no_perf = false; // disable performance metrics
+    bool ctx_shift = true; // context shift on inifinite text generation
 
     bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix
     bool logits_all = false; // return logits for all tokens in the batch
@@ -232,7 +280,7 @@ struct gpt_params {
     std::string cache_type_v = "f16"; // KV cache data type for the V
 
     // multimodal models (see examples/llava)
-    std::string mmproj = ""; // path to multimodal projector
+    std::string mmproj = ""; // path to multimodal projector // NOLINT
     std::vector<std::string> image; // path to image file(s)
 
     // embedding
@@ -248,15 +296,15 @@ struct gpt_params {
     int n_threads_http = -1; // number of threads to process HTTP requests (TODO: support threadpool)
 
     std::string hostname = "127.0.0.1";
-    std::string public_path = "";
-    std::string chat_template = "";
-    std::string system_prompt = "";
+    std::string public_path = ""; // NOLINT
+    std::string chat_template = ""; // NOLINT
+    std::string system_prompt = ""; // NOLINT
     bool enable_chat_template = true;
 
     std::vector<std::string> api_keys;
 
-    std::string ssl_file_key = "";
-    std::string ssl_file_cert = "";
+    std::string ssl_file_key = ""; // NOLINT
+    std::string ssl_file_cert = ""; // NOLINT
 
     bool endpoint_slots = true;
     bool endpoint_metrics = false;
@@ -311,91 +359,9 @@ struct gpt_params {
     bool batched_bench_output_jsonl = false;
 };
 
-struct llama_arg {
-    std::set<enum llama_example> examples = {LLAMA_EXAMPLE_COMMON};
-    std::vector<const char *> args;
-    const char * value_hint = nullptr; // help text or example for arg value
-    const char * value_hint_2 = nullptr; // for second arg value
-    const char * env = nullptr;
-    std::string help;
-    void (*handler_void) (gpt_params & params) = nullptr;
-    void (*handler_string) (gpt_params & params, const std::string &) = nullptr;
-    void (*handler_str_str)(gpt_params & params, const std::string &, const std::string &) = nullptr;
-    void (*handler_int) (gpt_params & params, int) = nullptr;
-
-    llama_arg(
-        const std::initializer_list<const char *> & args,
-        const char * value_hint,
-        const std::string & help,
-        void (*handler)(gpt_params & params, const std::string &)
-    ) : args(args), value_hint(value_hint), help(help), handler_string(handler) {}
-
-    llama_arg(
-        const std::initializer_list<const char *> & args,
-        const char * value_hint,
-        const std::string & help,
-        void (*handler)(gpt_params & params, int)
-    ) : args(args), value_hint(value_hint), help(help), handler_int(handler) {}
-
-    llama_arg(
-        const std::initializer_list<const char *> & args,
-        const std::string & help,
-        void (*handler)(gpt_params & params)
-    ) : args(args), help(help), handler_void(handler) {}
-
-    // support 2 values for arg
-    llama_arg(
-        const std::initializer_list<const char *> & args,
-        const char * value_hint,
-        const char * value_hint_2,
-        const std::string & help,
-        void (*handler)(gpt_params & params, const std::string &, const std::string &)
-    ) : args(args), value_hint(value_hint), value_hint_2(value_hint_2), help(help), handler_str_str(handler) {}
-
-    llama_arg & set_examples(std::initializer_list<enum llama_example> examples) {
-        this->examples = std::move(examples);
-        return *this;
-    }
-
-    llama_arg & set_env(const char * env) {
-        help = help + "\n(env: " + env + ")";
-        this->env = env;
-        return *this;
-    }
-
-    bool in_example(enum llama_example ex) {
-        return examples.find(ex) != examples.end();
-    }
-
-    bool get_value_from_env(std::string & output) const {
-        if (env == nullptr) return false;
-        char * value = std::getenv(env);
-        if (value) {
-            output = value;
-            return true;
-        }
-        return false;
-    }
-
-    bool has_value_from_env() const {
-        return env != nullptr && std::getenv(env);
-    }
-
-    std::string to_string();
-};
-
-// initialize list of options (arguments) that can be used by the current example
-std::vector<llama_arg> gpt_params_parser_init(gpt_params & params, llama_example ex);
-// optionally, we can provide "print_usage" to print example usage
-std::vector<llama_arg> gpt_params_parser_init(gpt_params & params, llama_example ex, std::function<void(int, char **)> print_usage);
-
-// parse input arguments from CLI
-// if one argument has invalid value, it will automatically display usage of the specific argument (and not the full usage message)
-bool gpt_params_parse (int argc, char ** argv, gpt_params & params, std::vector<llama_arg> & options);
-bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params, std::vector<llama_arg> & options);
-
-// print full usage message; it will be called internally by gpt_params_parse() if "-h" is set
-void gpt_params_print_usage(gpt_params & params, std::vector<llama_arg> & options);
+// call once at the start of a program if it uses libcommon
+// initializes the logging system and prints info about the build
+void gpt_init();
 
 std::string gpt_params_get_system_info(const gpt_params & params);
 
@@ -432,6 +398,11 @@ static std::vector<T> string_split(const std::string & str, char delim) {
 bool string_parse_kv_override(const char * data, std::vector<llama_model_kv_override> & overrides);
 void string_process_escapes(std::string & input);
 
+std::string string_from(bool value);
+std::string string_from(const std::vector<int> & values);
+std::string string_from(const struct llama_context * ctx, const std::vector<llama_token> & tokens);
+std::string string_from(const struct llama_context * ctx, const struct llama_batch & batch);
+
 //
 // Filesystem utils
 //
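
Note on the common.h hunks above: the CLI-parsing surface (llama_arg, gpt_params_parse*, gpt_params_print_usage) moves out of this header, the sampler parameters move in as gpt_sampler_params, and gpt_init() becomes the single setup call. A minimal sketch of how a program might use the new surface, assuming behavior matches the declarations shown in this diff (the model path and parameter values below are placeholders, not taken from the package):

#include <cstdio>
#include "common.h"

int main() {
    gpt_init(); // per the new header comment: initializes logging, prints build info

    gpt_params params;
    params.model        = "/path/to/model.gguf"; // hypothetical path
    params.prompt       = "Hello";
    params.n_predict    = 32;
    params.sparams.temp = 0.8f; // gpt_sampler_params now lives in common.h
    params.sparams.samplers = { GPT_SAMPLER_TYPE_TOP_K, GPT_SAMPLER_TYPE_TEMPERATURE };

    // system info string for the chosen parameters
    std::printf("%s\n", gpt_params_get_system_info(params).c_str());
    return 0;
}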
package/cpp/ggml-impl.h CHANGED
@@ -629,8 +629,16 @@ inline static float lm_ggml_lookup_fp16_to_fp32(lm_ggml_fp16_t f) {
 #define LM_GGML_FP32_TO_FP16(x) LM_GGML_COMPUTE_FP32_TO_FP16(x)
 #endif
 
+enum lm_ggml_cgraph_eval_order {
+    LM_GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT = 0,
+    LM_GGML_CGRAPH_EVAL_ORDER_RIGHT_TO_LEFT,
+    LM_GGML_CGRAPH_EVAL_ORDER_COUNT
+};
+
 // bitset
 
+typedef uint32_t lm_ggml_bitset_t;
+
 static_assert(sizeof(lm_ggml_bitset_t) == 4, "bitset_t constants must be updated");
 #define BITSET_SHR 5 // log2(sizeof(lm_ggml_bitset_t)*8)
 #define BITSET_MASK (sizeof(lm_ggml_bitset_t)*8 - 1)
@@ -656,6 +664,12 @@ static inline void lm_ggml_bitset_clear(lm_ggml_bitset_t * bitset, size_t i) {
 #define LM_GGML_HASHSET_FULL ((size_t)-1)
 #define LM_GGML_HASHSET_ALREADY_EXISTS ((size_t)-2)
 
+struct lm_ggml_hash_set {
+    size_t size;
+    lm_ggml_bitset_t * used; // whether or not the keys are in use i.e. set
+    struct lm_ggml_tensor ** keys; // actual tensors in the set, keys[i] is only defined if lm_ggml_bitset_get(used, i)
+};
+
 struct lm_ggml_hash_set lm_ggml_hash_set_new(size_t size);
 void lm_ggml_hash_set_free(struct lm_ggml_hash_set * hash_set);
 
@@ -745,6 +759,24 @@ static size_t lm_ggml_hash_find_or_insert(struct lm_ggml_hash_set * hash_set, st
     LM_GGML_ABORT("fatal error");
 }
 
+// computation graph
+
+struct lm_ggml_cgraph {
+    int size;
+    int n_nodes;
+    int n_leafs;
+
+    struct lm_ggml_tensor ** nodes;
+    struct lm_ggml_tensor ** grads;
+    struct lm_ggml_tensor ** leafs;
+
+    struct lm_ggml_hash_set visited_hash_set;
+
+    enum lm_ggml_cgraph_eval_order order;
+};
+
+struct lm_ggml_cgraph lm_ggml_graph_view(struct lm_ggml_cgraph * cgraph, int i0, int i1);
+
 #ifdef __cplusplus
 }
 #endif
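
The ggml-impl.h hunks above move lm_ggml_cgraph, lm_ggml_hash_set, the bitset typedef, and the eval-order enum into the internal header and expose lm_ggml_graph_view(). A sketch of the kind of node-range splitting this enables for backends (the chunking helper is hypothetical, and the [i0, i1) reading of lm_ggml_graph_view is an assumption based on ggml's usual convention):

#include "ggml-impl.h"

// split a graph into n_chunks contiguous node ranges, e.g. one per command buffer
static void process_in_chunks(struct lm_ggml_cgraph * gf, int n_chunks) {
    const int n_nodes     = gf->n_nodes;
    const int n_per_chunk = (n_nodes + n_chunks - 1) / n_chunks; // ceiling division
    for (int c = 0; c < n_chunks; ++c) {
        const int i0 = c * n_per_chunk;
        const int i1 = (i0 + n_per_chunk < n_nodes) ? (i0 + n_per_chunk) : n_nodes;
        if (i0 >= i1) break;
        struct lm_ggml_cgraph view = lm_ggml_graph_view(gf, i0, i1); // assumed: shallow view over nodes [i0, i1)
        (void) view; // a backend would encode/execute view.n_nodes nodes here
    }
}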
package/cpp/ggml-metal.m CHANGED
@@ -1,7 +1,7 @@
 #import "ggml-metal.h"
 
+#import "ggml-impl.h"
 #import "ggml-backend-impl.h"
-#import "ggml.h"
 
 #import <Foundation/Foundation.h>
 
@@ -13,13 +13,16 @@
 #define MAX(a, b) ((a) > (b) ? (a) : (b))
 
 #ifdef LM_GGML_METAL_NDEBUG
+#define LM_GGML_METAL_LOG(...)
 #define LM_GGML_METAL_LOG_INFO(...)
 #define LM_GGML_METAL_LOG_WARN(...)
 #define LM_GGML_METAL_LOG_ERROR(...)
 #else
-#define LM_GGML_METAL_LOG_INFO(...) lm_ggml_metal_log(LM_GGML_LOG_LEVEL_INFO, __VA_ARGS__)
-#define LM_GGML_METAL_LOG_WARN(...) lm_ggml_metal_log(LM_GGML_LOG_LEVEL_WARN, __VA_ARGS__)
+#define LM_GGML_METAL_LOG(...) lm_ggml_metal_log(LM_GGML_LOG_LEVEL_NONE, __VA_ARGS__)
+#define LM_GGML_METAL_LOG_INFO(...) lm_ggml_metal_log(LM_GGML_LOG_LEVEL_INFO, __VA_ARGS__)
+#define LM_GGML_METAL_LOG_WARN(...) lm_ggml_metal_log(LM_GGML_LOG_LEVEL_WARN, __VA_ARGS__)
 #define LM_GGML_METAL_LOG_ERROR(...) lm_ggml_metal_log(LM_GGML_LOG_LEVEL_ERROR, __VA_ARGS__)
+#define LM_GGML_METAL_LOG_DEBUG(...) lm_ggml_metal_log(LM_GGML_LOG_LEVEL_DEBUG, __VA_ARGS__)
 #endif
 
 #define UNUSED(x) (void)(x)
@@ -882,7 +885,7 @@ static enum lm_ggml_status lm_ggml_metal_graph_compute(
     // create multiple command buffers and enqueue them
     // then, we encode the graph into the command buffers in parallel
 
-    const int n_nodes = gf->n_nodes;
+    const int n_nodes = gf->n_nodes;
     const int n_cb = ctx->n_cb;
     const int n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb;
 
@@ -3039,8 +3042,7 @@ static enum lm_ggml_status lm_ggml_metal_graph_compute(
         if (status != MTLCommandBufferStatusCompleted) {
             LM_GGML_METAL_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status);
             if (status == MTLCommandBufferStatusError) {
-                NSString * error_code = [command_buffer error].localizedDescription;
-                LM_GGML_METAL_LOG_INFO("error: %s\n", [error_code UTF8String]);
+                LM_GGML_METAL_LOG_INFO("error: %s\n", [[command_buffer error].localizedDescription UTF8String]);
             }
 
             return LM_GGML_STATUS_FAILED;
@@ -3184,7 +3186,7 @@ static void lm_ggml_backend_metal_log_allocated_size(id<MTLDevice> device, size_
 #ifndef LM_GGML_METAL_NDEBUG
 #if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15)
     if (@available(macOS 10.12, iOS 16.0, *)) {
-        LM_GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB, (%8.2f / %8.2f)",
+        LM_GGML_METAL_LOG_DEBUG("%s: allocated buffer, size = %8.2f MiB, (%8.2f / %8.2f)\n",
                 __func__,
                 size_aligned / 1024.0 / 1024.0,
                 device.currentAllocatedSize / 1024.0 / 1024.0,
@@ -3192,8 +3194,6 @@ static void lm_ggml_backend_metal_log_allocated_size(id<MTLDevice> device, size_
 
         if (device.currentAllocatedSize > device.recommendedMaxWorkingSetSize) {
             LM_GGML_METAL_LOG_WARN("%s: warning: current allocated size is greater than the recommended max working set size\n", __func__);
-        } else {
-            LM_GGML_METAL_LOG_INFO("\n");
         }
     } else {
         LM_GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB, (%8.2f)\n",
@@ -3225,15 +3225,19 @@ LM_GGML_CALL static lm_ggml_backend_buffer_t lm_ggml_backend_metal_buffer_type_a
     ctx->n_buffers = 1;
 
     if (ctx->all_data != NULL) {
-        ctx->buffers[0].data = ctx->all_data;
-        ctx->buffers[0].size = size;
-        ctx->buffers[0].metal = [device newBufferWithBytesNoCopy:ctx->all_data
-                        length:size_aligned
-                        options:MTLResourceStorageModeShared
-                        deallocator:nil];
+        ctx->buffers[0].data = ctx->all_data;
+        ctx->buffers[0].size = size;
+        ctx->buffers[0].metal = nil;
+
+        if (size_aligned > 0) {
+            ctx->buffers[0].metal = [device newBufferWithBytesNoCopy:ctx->all_data
+                            length:size_aligned
+                            options:MTLResourceStorageModeShared
+                            deallocator:nil];
+        }
     }
 
-    if (ctx->all_data == NULL || ctx->buffers[0].metal == nil) {
+    if (size_aligned > 0 && (ctx->all_data == NULL || ctx->buffers[0].metal == nil)) {
         LM_GGML_METAL_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0);
         free(ctx);
         lm_ggml_backend_metal_free_device();
@@ -3310,14 +3314,17 @@ LM_GGML_CALL lm_ggml_backend_buffer_t lm_ggml_backend_metal_buffer_from_ptr(void
 
     // the buffer fits into the max buffer size allowed by the device
     if (size_aligned <= device.maxBufferLength) {
-        ctx->buffers[ctx->n_buffers].data = data;
-        ctx->buffers[ctx->n_buffers].size = size;
+        ctx->buffers[ctx->n_buffers].data = data;
+        ctx->buffers[ctx->n_buffers].size = size;
+        ctx->buffers[ctx->n_buffers].metal = nil;
 
-        ctx->buffers[ctx->n_buffers].metal = [device newBufferWithBytesNoCopy:data length:size_aligned options:MTLResourceStorageModeShared deallocator:nil];
+        if (size_aligned > 0) {
+            ctx->buffers[ctx->n_buffers].metal = [device newBufferWithBytesNoCopy:data length:size_aligned options:MTLResourceStorageModeShared deallocator:nil];
 
-        if (ctx->buffers[ctx->n_buffers].metal == nil) {
-            LM_GGML_METAL_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0);
-            return false;
+            if (ctx->buffers[ctx->n_buffers].metal == nil) {
+                LM_GGML_METAL_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0);
+                return false;
+            }
         }
 
         lm_ggml_backend_metal_log_allocated_size(device, size_aligned);
@@ -3333,14 +3340,17 @@ LM_GGML_CALL lm_ggml_backend_buffer_t lm_ggml_backend_metal_buffer_from_ptr(void
         for (size_t i = 0; i < size; i += size_step) {
            const size_t size_step_aligned = (i + size_view <= size) ? size_view : (size_aligned - i);
 
-            ctx->buffers[ctx->n_buffers].data = (void *) ((uint8_t *) data + i);
-            ctx->buffers[ctx->n_buffers].size = size_step_aligned;
+            ctx->buffers[ctx->n_buffers].data = (void *) ((uint8_t *) data + i);
+            ctx->buffers[ctx->n_buffers].size = size_step_aligned;
+            ctx->buffers[ctx->n_buffers].metal = nil;
 
-            ctx->buffers[ctx->n_buffers].metal = [device newBufferWithBytesNoCopy:(void *) ((uint8_t *) data + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil];
+            if (size_step_aligned > 0) {
+                ctx->buffers[ctx->n_buffers].metal = [device newBufferWithBytesNoCopy:(void *) ((uint8_t *) data + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil];
 
-            if (ctx->buffers[ctx->n_buffers].metal == nil) {
-                LM_GGML_METAL_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_step_aligned / 1024.0 / 1024.0);
-                return false;
+                if (ctx->buffers[ctx->n_buffers].metal == nil) {
+                    LM_GGML_METAL_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_step_aligned / 1024.0 / 1024.0);
+                    return false;
+                }
            }
 
            lm_ggml_backend_metal_log_allocated_size(device, size_step_aligned);
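
The ggml-metal.m buffer hunks above all apply one pattern: initialize the Metal buffer handle to nil, call newBufferWithBytesNoCopy only when the aligned size is non-zero, and treat a nil handle as an allocation failure only for non-zero sizes. A generic C++ sketch of that guard, for illustration only (map_region and its allocator argument are hypothetical, not part of the package):

#include <cstddef>

template <typename AllocFn>
bool map_region(void * data, size_t size_aligned, AllocFn alloc) {
    void * handle = nullptr;                // analogous to ctx->buffers[i].metal = nil
    if (size_aligned > 0) {
        handle = alloc(data, size_aligned); // zero-sized requests are skipped entirely
        if (handle == nullptr) {
            return false;                   // failure only matters for non-zero sizes
        }
    }
    return true;                            // empty regions are accepted without a backing buffer
}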