@fugood/llama.node 1.2.0 → 1.2.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -57,12 +57,32 @@ static std::string read_file(const std::string & fname) {
57
57
  }
58
58
 
59
59
  static void write_file(const std::string & fname, const std::string & content) {
60
- std::ofstream file(fname);
60
+ const std::string fname_tmp = fname + ".tmp";
61
+ std::ofstream file(fname_tmp);
61
62
  if (!file) {
62
63
  throw std::runtime_error(string_format("error: failed to open file '%s'\n", fname.c_str()));
63
64
  }
64
- file << content;
65
- file.close();
65
+
66
+ try {
67
+ file << content;
68
+ file.close();
69
+
70
+ // Makes write atomic
71
+ if (rename(fname_tmp.c_str(), fname.c_str()) != 0) {
72
+ LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, fname_tmp.c_str(), fname.c_str());
73
+ // If rename fails, try to delete the temporary file
74
+ if (remove(fname_tmp.c_str()) != 0) {
75
+ LOG_ERR("%s: unable to delete temporary file: %s\n", __func__, fname_tmp.c_str());
76
+ }
77
+ }
78
+ } catch (...) {
79
+ // If anything fails, try to delete the temporary file
80
+ if (remove(fname_tmp.c_str()) != 0) {
81
+ LOG_ERR("%s: unable to delete temporary file: %s\n", __func__, fname_tmp.c_str());
82
+ }
83
+
84
+ throw std::runtime_error(string_format("error: failed to write file '%s'\n", fname.c_str()));
85
+ }
66
86
  }
67
87
 
68
88
  common_arg & common_arg::set_examples(std::initializer_list<enum llama_example> examples) {
@@ -217,250 +237,294 @@ struct curl_slist_ptr {
217
237
  }
218
238
  };
219
239
 
220
- #define CURL_MAX_RETRY 3
221
- #define CURL_RETRY_DELAY_SECONDS 2
240
+ static CURLcode common_curl_perf(CURL * curl) {
241
+ CURLcode res = curl_easy_perform(curl);
242
+ if (res != CURLE_OK) {
243
+ LOG_ERR("%s: curl_easy_perform() failed\n", __func__);
244
+ }
245
+
246
+ return res;
247
+ }
222
248
 
223
- static bool curl_perform_with_retry(const std::string & url, CURL * curl, int max_attempts, int retry_delay_seconds, const char * method_name) {
224
- int remaining_attempts = max_attempts;
249
// Response header values of interest for a download; filled in by
// common_header_callback(). A field is left empty until a matching header
// line has been seen.
struct common_load_model_from_url_headers {
    std::string etag;           // value of the "ETag" header
    std::string last_modified;  // value of the "Last-Modified" header
    std::string accept_ranges;  // value of the "Accept-Ranges" header
};
225
255
 
226
- while (remaining_attempts > 0) {
227
- LOG_INF("%s: %s %s (attempt %d of %d)...\n", __func__ , method_name, url.c_str(), max_attempts - remaining_attempts + 1, max_attempts);
256
// Deleter for std::unique_ptr<FILE, FILE_deleter>: closes the stream it is
// handed.
struct FILE_deleter {
    void operator()(FILE * f) const {
        // fclose(nullptr) is undefined behaviour, so tolerate a null stream
        // when the deleter is invoked directly
        if (f != nullptr) {
            fclose(f);
        }
    }
};
228
259
 
229
- CURLcode res = curl_easy_perform(curl);
230
- if (res == CURLE_OK) {
231
- return true;
260
+ static size_t common_header_callback(char * buffer, size_t, size_t n_items, void * userdata) {
261
+ common_load_model_from_url_headers * headers = (common_load_model_from_url_headers *) userdata;
262
+ static std::regex header_regex("([^:]+): (.*)\r\n");
263
+ static std::regex etag_regex("ETag", std::regex_constants::icase);
264
+ static std::regex last_modified_regex("Last-Modified", std::regex_constants::icase);
265
+ static std::regex accept_ranges_regex("Accept-Ranges", std::regex_constants::icase);
266
+ std::string header(buffer, n_items);
267
+ std::smatch match;
268
+ if (std::regex_match(header, match, header_regex)) {
269
+ const std::string & key = match[1];
270
+ const std::string & value = match[2];
271
+ if (std::regex_match(key, match, etag_regex)) {
272
+ headers->etag = value;
273
+ } else if (std::regex_match(key, match, last_modified_regex)) {
274
+ headers->last_modified = value;
275
+ } else if (std::regex_match(key, match, accept_ranges_regex)) {
276
+ headers->accept_ranges = value;
232
277
  }
278
+ }
279
+
280
+ return n_items;
281
+ }
233
282
 
234
- int exponential_backoff_delay = std::pow(retry_delay_seconds, max_attempts - remaining_attempts) * 1000;
235
- LOG_WRN("%s: curl_easy_perform() failed: %s, retrying after %d milliseconds...\n", __func__, curl_easy_strerror(res), exponential_backoff_delay);
283
// libcurl write callback: writes the received chunk to the FILE * passed in
// `fd` and returns fwrite()'s result, i.e. the number of `size`-byte items
// written.
static size_t common_write_callback(void * data, size_t size, size_t nmemb, void * fd) {
    FILE * out = static_cast<FILE *>(fd);
    const size_t n_written = std::fwrite(data, size, nmemb, out);
    return n_written;
}
236
286
 
237
- remaining_attempts--;
238
- if (remaining_attempts == 0) break;
239
- std::this_thread::sleep_for(std::chrono::milliseconds(exponential_backoff_delay));
287
+ // helper function to hide password in URL
288
// Returns `url` with its credentials masked so it can be logged safely.
//
// For a URL of the form scheme://user[:password]@host[...] the
// "user[:password]" part is replaced by "********". A URL without credentials,
// or a string that does not look like a URL at all, is returned unchanged.
static std::string llama_download_hide_password_in_url(const std::string & url) {
    // match[1] = scheme (e.g., "https://")
    // match[2] = user[:password]@ part (optional)
    // match[3] = rest of URL (host and path)
    // The credentials may not contain '/', '?' or '#', so an '@' that appears
    // later in the path, query or fragment is not mistaken for them.
    static const std::regex url_regex(R"(^([A-Za-z][A-Za-z0-9+.-]*://)([^/?#@]+@)?(.*)$)");
    std::smatch match;

    if (std::regex_match(url, match, url_regex) && match[2].matched) {
        return match[1].str() + "********@" + match[3].str();
    }

    return url; // No credentials found or malformed URL
}
243
303
 
244
- return false;
304
+ static void common_curl_easy_setopt_head(CURL * curl, const std::string & url) {
305
+ // Set the URL, allow to follow http redirection
306
+ curl_easy_setopt(curl, CURLOPT_URL, url.c_str());
307
+ curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L);
308
+
309
+ # if defined(_WIN32)
310
+ // CURLSSLOPT_NATIVE_CA tells libcurl to use standard certificate store of
311
+ // operating system. Currently implemented under MS-Windows.
312
+ curl_easy_setopt(curl, CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA);
313
+ # endif
314
+
315
+ curl_easy_setopt(curl, CURLOPT_NOBODY, 1L); // will trigger the HEAD verb
316
+ curl_easy_setopt(curl, CURLOPT_NOPROGRESS, 1L); // hide head request progress
317
+ curl_easy_setopt(curl, CURLOPT_HEADERFUNCTION, common_header_callback);
245
318
  }
246
319
 
247
- // download one single file from remote URL to local path
248
- static bool common_download_file_single(const std::string & url, const std::string & path, const std::string & bearer_token, bool offline) {
249
- // Check if the file already exists locally
250
- auto file_exists = std::filesystem::exists(path);
320
+ static void common_curl_easy_setopt_get(CURL * curl) {
321
+ curl_easy_setopt(curl, CURLOPT_NOBODY, 0L);
322
+ curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, common_write_callback);
251
323
 
252
- // If the file exists, check its JSON metadata companion file.
253
- std::string metadata_path = path + ".json";
254
- nlohmann::json metadata; // TODO @ngxson : get rid of this json, use regex instead
255
- std::string etag;
256
- std::string last_modified;
324
+ // display download progress
325
+ curl_easy_setopt(curl, CURLOPT_NOPROGRESS, 0L);
326
+ }
257
327
 
258
- if (file_exists) {
259
- if (offline) {
260
- LOG_INF("%s: using cached file (offline mode): %s\n", __func__, path.c_str());
261
- return true; // skip verification/downloading
262
- }
263
- // Try and read the JSON metadata file (note: stream autoclosed upon exiting this block).
264
- std::ifstream metadata_in(metadata_path);
265
- if (metadata_in.good()) {
266
- try {
267
- metadata_in >> metadata;
268
- LOG_DBG("%s: previous metadata file found %s: %s\n", __func__, metadata_path.c_str(), metadata.dump().c_str());
269
- if (metadata.contains("etag") && metadata.at("etag").is_string()) {
270
- etag = metadata.at("etag");
271
- }
272
- if (metadata.contains("lastModified") && metadata.at("lastModified").is_string()) {
273
- last_modified = metadata.at("lastModified");
274
- }
275
- } catch (const nlohmann::json::exception & e) {
276
- LOG_ERR("%s: error reading metadata file %s: %s\n", __func__, metadata_path.c_str(), e.what());
277
- }
278
- }
279
- // if we cannot open the metadata file, we assume that the downloaded file is not valid (etag and last-modified are left empty, so we will download it again)
280
- } else {
281
- if (offline) {
282
- LOG_ERR("%s: required file is not available in cache (offline mode): %s\n", __func__, path.c_str());
283
- return false;
284
- }
285
- LOG_INF("%s: no previous model file found %s\n", __func__, path.c_str());
328
+ static bool common_pull_file(CURL * curl, const std::string & path_temporary) {
329
+ if (std::filesystem::exists(path_temporary)) {
330
+ const std::string partial_size = std::to_string(std::filesystem::file_size(path_temporary));
331
+ LOG_INF("%s: server supports range requests, resuming download from byte %s\n", __func__, partial_size.c_str());
332
+ const std::string range_str = partial_size + "-";
333
+ curl_easy_setopt(curl, CURLOPT_RANGE, range_str.c_str());
286
334
  }
287
335
 
288
- // Send a HEAD request to retrieve the etag and last-modified headers
289
- struct common_load_model_from_url_headers {
290
- std::string etag;
291
- std::string last_modified;
292
- };
336
+ // Always open file in append mode could be resuming
337
+ std::unique_ptr<FILE, FILE_deleter> outfile(fopen(path_temporary.c_str(), "ab"));
338
+ if (!outfile) {
339
+ LOG_ERR("%s: error opening local file for writing: %s\n", __func__, path_temporary.c_str());
340
+ return false;
341
+ }
293
342
 
294
- common_load_model_from_url_headers headers;
295
- bool head_request_ok = false;
296
- bool should_download = !file_exists; // by default, we should download if the file does not exist
343
+ common_curl_easy_setopt_get(curl);
344
+ curl_easy_setopt(curl, CURLOPT_WRITEDATA, outfile.get());
297
345
 
298
- // Initialize libcurl
299
- curl_ptr curl(curl_easy_init(), &curl_easy_cleanup);
300
- curl_slist_ptr http_headers;
346
+ return common_curl_perf(curl) == CURLE_OK;
347
+ }
348
+
349
+ static bool common_download_head(CURL * curl,
350
+ curl_slist_ptr & http_headers,
351
+ const std::string & url,
352
+ const std::string & bearer_token) {
301
353
  if (!curl) {
302
354
  LOG_ERR("%s: error initializing libcurl\n", __func__);
303
355
  return false;
304
356
  }
305
357
 
306
- // Set the URL, allow to follow http redirection
307
- curl_easy_setopt(curl.get(), CURLOPT_URL, url.c_str());
308
- curl_easy_setopt(curl.get(), CURLOPT_FOLLOWLOCATION, 1L);
309
-
310
358
  http_headers.ptr = curl_slist_append(http_headers.ptr, "User-Agent: llama-cpp");
311
359
  // Check if hf-token or bearer-token was specified
312
360
  if (!bearer_token.empty()) {
313
361
  std::string auth_header = "Authorization: Bearer " + bearer_token;
314
- http_headers.ptr = curl_slist_append(http_headers.ptr, auth_header.c_str());
362
+ http_headers.ptr = curl_slist_append(http_headers.ptr, auth_header.c_str());
315
363
  }
316
- curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers.ptr);
317
364
 
318
- #if defined(_WIN32)
319
- // CURLSSLOPT_NATIVE_CA tells libcurl to use standard certificate store of
320
- // operating system. Currently implemented under MS-Windows.
321
- curl_easy_setopt(curl.get(), CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA);
322
- #endif
323
-
324
- typedef size_t(*CURLOPT_HEADERFUNCTION_PTR)(char *, size_t, size_t, void *);
325
- auto header_callback = [](char * buffer, size_t /*size*/, size_t n_items, void * userdata) -> size_t {
326
- common_load_model_from_url_headers * headers = (common_load_model_from_url_headers *) userdata;
327
-
328
- static std::regex header_regex("([^:]+): (.*)\r\n");
329
- static std::regex etag_regex("ETag", std::regex_constants::icase);
330
- static std::regex last_modified_regex("Last-Modified", std::regex_constants::icase);
365
+ curl_easy_setopt(curl, CURLOPT_HTTPHEADER, http_headers.ptr);
366
+ common_curl_easy_setopt_head(curl, url);
367
+ return common_curl_perf(curl) == CURLE_OK;
368
+ }
331
369
 
332
- std::string header(buffer, n_items);
333
- std::smatch match;
334
- if (std::regex_match(header, match, header_regex)) {
335
- const std::string & key = match[1];
336
- const std::string & value = match[2];
337
- if (std::regex_match(key, match, etag_regex)) {
338
- headers->etag = value;
339
- } else if (std::regex_match(key, match, last_modified_regex)) {
340
- headers->last_modified = value;
370
+ // download one single file from remote URL to local path
371
+ static bool common_download_file_single(const std::string & url,
372
+ const std::string & path,
373
+ const std::string & bearer_token,
374
+ bool offline) {
375
+ // If the file exists, check its JSON metadata companion file.
376
+ std::string metadata_path = path + ".json";
377
+ static const int max_attempts = 3;
378
+ static const int retry_delay_seconds = 2;
379
+ for (int i = 0; i < max_attempts; ++i) {
380
+ nlohmann::json metadata; // TODO @ngxson : get rid of this json, use regex instead
381
+ std::string etag;
382
+ std::string last_modified;
383
+
384
+ // Check if the file already exists locally
385
+ const auto file_exists = std::filesystem::exists(path);
386
+ if (file_exists) {
387
+ if (offline) {
388
+ LOG_INF("%s: using cached file (offline mode): %s\n", __func__, path.c_str());
389
+ return true; // skip verification/downloading
390
+ }
391
+ // Try and read the JSON metadata file (note: stream autoclosed upon exiting this block).
392
+ std::ifstream metadata_in(metadata_path);
393
+ if (metadata_in.good()) {
394
+ try {
395
+ metadata_in >> metadata;
396
+ LOG_DBG("%s: previous metadata file found %s: %s\n", __func__, metadata_path.c_str(),
397
+ metadata.dump().c_str());
398
+ if (metadata.contains("etag") && metadata.at("etag").is_string()) {
399
+ etag = metadata.at("etag");
400
+ }
401
+ if (metadata.contains("lastModified") && metadata.at("lastModified").is_string()) {
402
+ last_modified = metadata.at("lastModified");
403
+ }
404
+ } catch (const nlohmann::json::exception & e) {
405
+ LOG_ERR("%s: error reading metadata file %s: %s\n", __func__, metadata_path.c_str(), e.what());
406
+ }
341
407
  }
408
+ // if we cannot open the metadata file, we assume that the downloaded file is not valid (etag and last-modified are left empty, so we will download it again)
409
+ } else {
410
+ if (offline) {
411
+ LOG_ERR("%s: required file is not available in cache (offline mode): %s\n", __func__, path.c_str());
412
+ return false;
413
+ }
414
+ LOG_INF("%s: no previous model file found %s\n", __func__, path.c_str());
342
415
  }
343
- return n_items;
344
- };
345
-
346
- curl_easy_setopt(curl.get(), CURLOPT_NOBODY, 1L); // will trigger the HEAD verb
347
- curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 1L); // hide head request progress
348
- curl_easy_setopt(curl.get(), CURLOPT_HEADERFUNCTION, static_cast<CURLOPT_HEADERFUNCTION_PTR>(header_callback));
349
- curl_easy_setopt(curl.get(), CURLOPT_HEADERDATA, &headers);
350
416
 
351
- // we only allow retrying once for HEAD requests
352
- // this is for the use case of using running offline (no internet), retrying can be annoying
353
- bool was_perform_successful = curl_perform_with_retry(url, curl.get(), 1, 0, "HEAD");
354
- if (!was_perform_successful) {
355
- head_request_ok = false;
356
- }
357
-
358
- long http_code = 0;
359
- curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &http_code);
360
- if (http_code == 200) {
361
- head_request_ok = true;
362
- } else {
363
- LOG_WRN("%s: HEAD invalid http status code received: %ld\n", __func__, http_code);
364
- head_request_ok = false;
365
- }
417
+ bool head_request_ok = false;
418
+ bool should_download = !file_exists; // by default, we should download if the file does not exist
366
419
 
367
- // if head_request_ok is false, we don't have the etag or last-modified headers
368
- // we leave should_download as-is, which is true if the file does not exist
369
- if (head_request_ok) {
370
- // check if ETag or Last-Modified headers are different
371
- // if it is, we need to download the file again
372
- if (!etag.empty() && etag != headers.etag) {
373
- LOG_WRN("%s: ETag header is different (%s != %s): triggering a new download\n", __func__, etag.c_str(), headers.etag.c_str());
374
- should_download = true;
375
- } else if (!last_modified.empty() && last_modified != headers.last_modified) {
376
- LOG_WRN("%s: Last-Modified header is different (%s != %s): triggering a new download\n", __func__, last_modified.c_str(), headers.last_modified.c_str());
377
- should_download = true;
420
+ // Initialize libcurl
421
+ curl_ptr curl(curl_easy_init(), &curl_easy_cleanup);
422
+ common_load_model_from_url_headers headers;
423
+ curl_easy_setopt(curl.get(), CURLOPT_HEADERDATA, &headers);
424
+ curl_slist_ptr http_headers;
425
+ const bool was_perform_successful = common_download_head(curl.get(), http_headers, url, bearer_token);
426
+ if (!was_perform_successful) {
427
+ head_request_ok = false;
378
428
  }
379
- }
380
429
 
381
- if (should_download) {
382
- std::string path_temporary = path + ".downloadInProgress";
383
- if (file_exists) {
384
- LOG_WRN("%s: deleting previous downloaded file: %s\n", __func__, path.c_str());
385
- if (remove(path.c_str()) != 0) {
386
- LOG_ERR("%s: unable to delete file: %s\n", __func__, path.c_str());
387
- return false;
430
+ long http_code = 0;
431
+ curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &http_code);
432
+ if (http_code == 200) {
433
+ head_request_ok = true;
434
+ } else {
435
+ LOG_WRN("%s: HEAD invalid http status code received: %ld\n", __func__, http_code);
436
+ head_request_ok = false;
437
+ }
438
+
439
+ // if head_request_ok is false, we don't have the etag or last-modified headers
440
+ // we leave should_download as-is, which is true if the file does not exist
441
+ bool should_download_from_scratch = false;
442
+ if (head_request_ok) {
443
+ // check if ETag or Last-Modified headers are different
444
+ // if it is, we need to download the file again
445
+ if (!etag.empty() && etag != headers.etag) {
446
+ LOG_WRN("%s: ETag header is different (%s != %s): triggering a new download\n", __func__, etag.c_str(),
447
+ headers.etag.c_str());
448
+ should_download = true;
449
+ should_download_from_scratch = true;
450
+ } else if (!last_modified.empty() && last_modified != headers.last_modified) {
451
+ LOG_WRN("%s: Last-Modified header is different (%s != %s): triggering a new download\n", __func__,
452
+ last_modified.c_str(), headers.last_modified.c_str());
453
+ should_download = true;
454
+ should_download_from_scratch = true;
455
+ }
456
+ }
457
+
458
+ const bool accept_ranges_supported = !headers.accept_ranges.empty() && headers.accept_ranges != "none";
459
+ if (should_download) {
460
+ if (file_exists &&
461
+ !accept_ranges_supported) { // Resumable downloads not supported, delete and start again.
462
+ LOG_WRN("%s: deleting previous downloaded file: %s\n", __func__, path.c_str());
463
+ if (remove(path.c_str()) != 0) {
464
+ LOG_ERR("%s: unable to delete file: %s\n", __func__, path.c_str());
465
+ return false;
466
+ }
388
467
  }
389
- }
390
468
 
391
- // Set the output file
469
+ const std::string path_temporary = path + ".downloadInProgress";
470
+ if (should_download_from_scratch) {
471
+ if (std::filesystem::exists(path_temporary)) {
472
+ if (remove(path_temporary.c_str()) != 0) {
473
+ LOG_ERR("%s: unable to delete file: %s\n", __func__, path_temporary.c_str());
474
+ return false;
475
+ }
476
+ }
392
477
 
393
- struct FILE_deleter {
394
- void operator()(FILE * f) const {
395
- fclose(f);
478
+ if (std::filesystem::exists(path)) {
479
+ if (remove(path.c_str()) != 0) {
480
+ LOG_ERR("%s: unable to delete file: %s\n", __func__, path.c_str());
481
+ return false;
482
+ }
483
+ }
396
484
  }
397
- };
398
-
399
- std::unique_ptr<FILE, FILE_deleter> outfile(fopen(path_temporary.c_str(), "wb"));
400
- if (!outfile) {
401
- LOG_ERR("%s: error opening local file for writing: %s\n", __func__, path.c_str());
402
- return false;
403
- }
404
-
405
- typedef size_t(*CURLOPT_WRITEFUNCTION_PTR)(void * data, size_t size, size_t nmemb, void * fd);
406
- auto write_callback = [](void * data, size_t size, size_t nmemb, void * fd) -> size_t {
407
- return fwrite(data, size, nmemb, (FILE *)fd);
408
- };
409
- curl_easy_setopt(curl.get(), CURLOPT_NOBODY, 0L);
410
- curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, static_cast<CURLOPT_WRITEFUNCTION_PTR>(write_callback));
411
- curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, outfile.get());
412
485
 
413
- // display download progress
414
- curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 0L);
486
+ // Write the updated JSON metadata file.
487
+ metadata.update({
488
+ { "url", url },
489
+ { "etag", headers.etag },
490
+ { "lastModified", headers.last_modified }
491
+ });
492
+ write_file(metadata_path, metadata.dump(4));
493
+ LOG_DBG("%s: file metadata saved: %s\n", __func__, metadata_path.c_str());
494
+
495
+ // start the download
496
+ LOG_INF("%s: trying to download model from %s to %s (server_etag:%s, server_last_modified:%s)...\n",
497
+ __func__, llama_download_hide_password_in_url(url).c_str(), path_temporary.c_str(),
498
+ headers.etag.c_str(), headers.last_modified.c_str());
499
+ const bool was_pull_successful = common_pull_file(curl.get(), path_temporary);
500
+ if (!was_pull_successful) {
501
+ if (i + 1 < max_attempts) {
502
+ const int exponential_backoff_delay = std::pow(retry_delay_seconds, i) * 1000;
503
+ LOG_WRN("%s: retrying after %d milliseconds...\n", __func__, exponential_backoff_delay);
504
+ std::this_thread::sleep_for(std::chrono::milliseconds(exponential_backoff_delay));
505
+ } else {
506
+ LOG_ERR("%s: curl_easy_perform() failed after %d attempts\n", __func__, max_attempts);
507
+ }
415
508
 
416
- // helper function to hide password in URL
417
- auto llama_download_hide_password_in_url = [](const std::string & url) -> std::string {
418
- std::size_t protocol_pos = url.find("://");
419
- if (protocol_pos == std::string::npos) {
420
- return url; // Malformed URL
509
+ continue;
421
510
  }
422
511
 
423
- std::size_t at_pos = url.find('@', protocol_pos + 3);
424
- if (at_pos == std::string::npos) {
425
- return url; // No password in URL
512
+ long http_code = 0;
513
+ curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &http_code);
514
+ if (http_code < 200 || http_code >= 400) {
515
+ LOG_ERR("%s: invalid http status code received: %ld\n", __func__, http_code);
516
+ return false;
426
517
  }
427
518
 
428
- return url.substr(0, protocol_pos + 3) + "********" + url.substr(at_pos);
429
- };
430
-
431
- // start the download
432
- LOG_INF("%s: trying to download model from %s to %s (server_etag:%s, server_last_modified:%s)...\n", __func__,
433
- llama_download_hide_password_in_url(url).c_str(), path.c_str(), headers.etag.c_str(), headers.last_modified.c_str());
434
- bool was_perform_successful = curl_perform_with_retry(url, curl.get(), CURL_MAX_RETRY, CURL_RETRY_DELAY_SECONDS, "GET");
435
- if (!was_perform_successful) {
436
- return false;
437
- }
438
-
439
- long http_code = 0;
440
- curl_easy_getinfo (curl.get(), CURLINFO_RESPONSE_CODE, &http_code);
441
- if (http_code < 200 || http_code >= 400) {
442
- LOG_ERR("%s: invalid http status code received: %ld\n", __func__, http_code);
443
- return false;
519
+ if (rename(path_temporary.c_str(), path.c_str()) != 0) {
520
+ LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, path_temporary.c_str(), path.c_str());
521
+ return false;
522
+ }
523
+ } else {
524
+ LOG_INF("%s: using cached file: %s\n", __func__, path.c_str());
444
525
  }
445
526
 
446
- // Causes file to be closed explicitly here before we rename it.
447
- outfile.reset();
448
-
449
- // Write the updated JSON metadata file.
450
- metadata.update({
451
- {"url", url},
452
- {"etag", headers.etag},
453
- {"lastModified", headers.last_modified}
454
- });
455
- write_file(metadata_path, metadata.dump(4));
456
- LOG_DBG("%s: file metadata saved: %s\n", __func__, metadata_path.c_str());
457
-
458
- if (rename(path_temporary.c_str(), path.c_str()) != 0) {
459
- LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, path_temporary.c_str(), path.c_str());
460
- return false;
461
- }
462
- } else {
463
- LOG_INF("%s: using cached file: %s\n", __func__, path.c_str());
527
+ break;
464
528
  }
465
529
 
466
530
  return true;
@@ -745,6 +809,124 @@ std::pair<long, std::vector<char>> common_remote_get_content(const std::string &
745
809
 
746
810
  #endif // LLAMA_USE_CURL
747
811
 
812
+ //
813
+ // Docker registry functions
814
+ //
815
+
816
+ static std::string common_docker_get_token(const std::string & repo) {
817
+ std::string url = "https://auth.docker.io/token?service=registry.docker.io&scope=repository:" + repo + ":pull";
818
+
819
+ common_remote_params params;
820
+ auto res = common_remote_get_content(url, params);
821
+
822
+ if (res.first != 200) {
823
+ throw std::runtime_error("Failed to get Docker registry token, HTTP code: " + std::to_string(res.first));
824
+ }
825
+
826
+ std::string response_str(res.second.begin(), res.second.end());
827
+ nlohmann::ordered_json response = nlohmann::ordered_json::parse(response_str);
828
+
829
+ if (!response.contains("token")) {
830
+ throw std::runtime_error("Docker registry token response missing 'token' field");
831
+ }
832
+
833
+ return response["token"].get<std::string>();
834
+ }
835
+
836
+ static std::string common_docker_resolve_model(const std::string & docker) {
837
+ // Parse ai/smollm2:135M-Q4_0
838
+ size_t colon_pos = docker.find(':');
839
+ std::string repo, tag;
840
+ if (colon_pos != std::string::npos) {
841
+ repo = docker.substr(0, colon_pos);
842
+ tag = docker.substr(colon_pos + 1);
843
+ } else {
844
+ repo = docker;
845
+ tag = "latest";
846
+ }
847
+
848
+ // ai/ is the default
849
+ size_t slash_pos = docker.find('/');
850
+ if (slash_pos == std::string::npos) {
851
+ repo.insert(0, "ai/");
852
+ }
853
+
854
+ LOG_INF("%s: Downloading Docker Model: %s:%s\n", __func__, repo.c_str(), tag.c_str());
855
+ try {
856
+ // --- helper: digest validation ---
857
+ auto validate_oci_digest = [](const std::string & digest) -> std::string {
858
+ // Expected: algo:hex ; start with sha256 (64 hex chars)
859
+ // You can extend this map if supporting other algorithms in future.
860
+ static const std::regex re("^sha256:([a-fA-F0-9]{64})$");
861
+ std::smatch m;
862
+ if (!std::regex_match(digest, m, re)) {
863
+ throw std::runtime_error("Invalid OCI digest format received in manifest: " + digest);
864
+ }
865
+ // normalize hex to lowercase
866
+ std::string normalized = digest;
867
+ std::transform(normalized.begin()+7, normalized.end(), normalized.begin()+7, [](unsigned char c){
868
+ return std::tolower(c);
869
+ });
870
+ return normalized;
871
+ };
872
+
873
+ std::string token = common_docker_get_token(repo); // Get authentication token
874
+
875
+ // Get manifest
876
+ const std::string url_prefix = "https://registry-1.docker.io/v2/" + repo;
877
+ std::string manifest_url = url_prefix + "/manifests/" + tag;
878
+ common_remote_params manifest_params;
879
+ manifest_params.headers.push_back("Authorization: Bearer " + token);
880
+ manifest_params.headers.push_back(
881
+ "Accept: application/vnd.docker.distribution.manifest.v2+json,application/vnd.oci.image.manifest.v1+json");
882
+ auto manifest_res = common_remote_get_content(manifest_url, manifest_params);
883
+ if (manifest_res.first != 200) {
884
+ throw std::runtime_error("Failed to get Docker manifest, HTTP code: " + std::to_string(manifest_res.first));
885
+ }
886
+
887
+ std::string manifest_str(manifest_res.second.begin(), manifest_res.second.end());
888
+ nlohmann::ordered_json manifest = nlohmann::ordered_json::parse(manifest_str);
889
+ std::string gguf_digest; // Find the GGUF layer
890
+ if (manifest.contains("layers")) {
891
+ for (const auto & layer : manifest["layers"]) {
892
+ if (layer.contains("mediaType")) {
893
+ std::string media_type = layer["mediaType"].get<std::string>();
894
+ if (media_type == "application/vnd.docker.ai.gguf.v3" ||
895
+ media_type.find("gguf") != std::string::npos) {
896
+ gguf_digest = layer["digest"].get<std::string>();
897
+ break;
898
+ }
899
+ }
900
+ }
901
+ }
902
+
903
+ if (gguf_digest.empty()) {
904
+ throw std::runtime_error("No GGUF layer found in Docker manifest");
905
+ }
906
+
907
+ // Validate & normalize digest
908
+ gguf_digest = validate_oci_digest(gguf_digest);
909
+ LOG_DBG("%s: Using validated digest: %s\n", __func__, gguf_digest.c_str());
910
+
911
+ // Prepare local filename
912
+ std::string model_filename = repo;
913
+ std::replace(model_filename.begin(), model_filename.end(), '/', '_');
914
+ model_filename += "_" + tag + ".gguf";
915
+ std::string local_path = fs_get_cache_file(model_filename);
916
+
917
+ const std::string blob_url = url_prefix + "/blobs/" + gguf_digest;
918
+ if (!common_download_file_single(blob_url, local_path, token, false)) {
919
+ throw std::runtime_error("Failed to download Docker Model");
920
+ }
921
+
922
+ LOG_INF("%s: Downloaded Docker Model to: %s\n", __func__, local_path.c_str());
923
+ return local_path;
924
+ } catch (const std::exception & e) {
925
+ LOG_ERR("%s: Docker Model download failed: %s\n", __func__, e.what());
926
+ throw;
927
+ }
928
+ }
929
+
748
930
  //
749
931
  // utils
750
932
  //
@@ -795,7 +977,9 @@ static handle_model_result common_params_handle_model(
795
977
  handle_model_result result;
796
978
  // handle pre-fill default model path and url based on hf_repo and hf_file
797
979
  {
798
- if (!model.hf_repo.empty()) {
980
+ if (!model.docker_repo.empty()) { // Handle Docker URLs by resolving them to local paths
981
+ model.path = common_docker_resolve_model(model.docker_repo);
982
+ } else if (!model.hf_repo.empty()) {
799
983
  // short-hand to avoid specifying --hf-file -> default it to --model
800
984
  if (model.hf_file.empty()) {
801
985
  if (model.path.empty()) {
@@ -1184,7 +1368,7 @@ static std::vector<ggml_backend_dev_t> parse_device_list(const std::string & val
1184
1368
  } else {
1185
1369
  for (const auto & device : dev_names) {
1186
1370
  auto * dev = ggml_backend_dev_by_name(device.c_str());
1187
- if (!dev || ggml_backend_dev_type(dev) != GGML_BACKEND_DEVICE_TYPE_GPU) {
1371
+ if (!dev || ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_CPU) {
1188
1372
  throw std::invalid_argument(string_format("invalid device: %s", device.c_str()));
1189
1373
  }
1190
1374
  devices.push_back(dev);
@@ -1194,7 +1378,7 @@ static std::vector<ggml_backend_dev_t> parse_device_list(const std::string & val
1194
1378
  return devices;
1195
1379
  }
1196
1380
 
1197
- static void add_rpc_devices(std::string servers) {
1381
+ static void add_rpc_devices(const std::string & servers) {
1198
1382
  auto rpc_servers = string_split<std::string>(servers, ',');
1199
1383
  if (rpc_servers.empty()) {
1200
1384
  throw std::invalid_argument("no RPC servers specified");
@@ -1584,7 +1768,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
1584
1768
  [](common_params & params, const std::string & value) {
1585
1769
  params.system_prompt = value;
1586
1770
  }
1587
- ).set_examples({LLAMA_EXAMPLE_MAIN}));
1771
+ ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_DIFFUSION}));
1588
1772
  add_opt(common_arg(
1589
1773
  {"--no-perf"},
1590
1774
  string_format("disable internal libllama performance timings (default: %s)", params.no_perf ? "true" : "false"),
@@ -2396,24 +2580,15 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
2396
2580
  {"--list-devices"},
2397
2581
  "print list of available devices and exit",
2398
2582
  [](common_params &) {
2399
- std::vector<ggml_backend_dev_t> rpc_devices;
2400
- std::vector<ggml_backend_dev_t> all_devices;
2583
+ std::vector<ggml_backend_dev_t> devices;
2401
2584
  for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
2402
2585
  auto * dev = ggml_backend_dev_get(i);
2403
- if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU) {
2404
- ggml_backend_reg_t reg = ggml_backend_dev_backend_reg(dev);
2405
- if (ggml_backend_reg_name(reg) == std::string("RPC")) {
2406
- rpc_devices.push_back(dev);
2407
- } else {
2408
- all_devices.push_back(dev);
2409
- }
2586
+ if (ggml_backend_dev_type(dev) != GGML_BACKEND_DEVICE_TYPE_CPU) {
2587
+ devices.push_back(dev);
2410
2588
  }
2411
2589
  }
2412
- // insert RPC devices in front
2413
- all_devices.insert(all_devices.begin(), rpc_devices.begin(), rpc_devices.end());
2414
2590
  printf("Available devices:\n");
2415
- for (size_t i = 0; i < all_devices.size(); ++i) {
2416
- auto * dev = all_devices[i];
2591
+ for (auto * dev : devices) {
2417
2592
  size_t free, total;
2418
2593
  ggml_backend_dev_memory(dev, &free, &total);
2419
2594
  printf(" %s: %s (%zu MiB, %zu MiB free)\n", ggml_backend_dev_name(dev), ggml_backend_dev_description(dev), total / 1024 / 1024, free / 1024 / 1024);
@@ -2437,7 +2612,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
2437
2612
  {"--cpu-moe", "-cmoe"},
2438
2613
  "keep all Mixture of Experts (MoE) weights in the CPU",
2439
2614
  [](common_params & params) {
2440
- params.tensor_buft_overrides.push_back({"\\.ffn_(up|down|gate)_exps", ggml_backend_cpu_buffer_type()});
2615
+ params.tensor_buft_overrides.push_back(llm_ffn_exps_cpu_override());
2441
2616
  }
2442
2617
  ).set_env("LLAMA_ARG_CPU_MOE"));
2443
2618
  add_opt(common_arg(
@@ -2450,7 +2625,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
2450
2625
  for (int i = 0; i < value; ++i) {
2451
2626
  // keep strings alive and avoid leaking memory by storing them in a static vector
2452
2627
  static std::list<std::string> buft_overrides;
2453
- buft_overrides.push_back(string_format("blk\\.%d\\.ffn_(up|down|gate)_exps", i));
2628
+ buft_overrides.push_back(llm_ffn_exps_block_regex(i));
2454
2629
  params.tensor_buft_overrides.push_back({buft_overrides.back().c_str(), ggml_backend_cpu_buffer_type()});
2455
2630
  }
2456
2631
  }
@@ -2459,7 +2634,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
2459
2634
  {"--cpu-moe-draft", "-cmoed"},
2460
2635
  "keep all Mixture of Experts (MoE) weights in the CPU for the draft model",
2461
2636
  [](common_params & params) {
2462
- params.speculative.tensor_buft_overrides.push_back({"\\.ffn_(up|down|gate)_exps", ggml_backend_cpu_buffer_type()});
2637
+ params.speculative.tensor_buft_overrides.push_back(llm_ffn_exps_cpu_override());
2463
2638
  }
2464
2639
  ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_CPU_MOE_DRAFT"));
2465
2640
  add_opt(common_arg(
@@ -2471,7 +2646,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
2471
2646
  }
2472
2647
  for (int i = 0; i < value; ++i) {
2473
2648
  static std::list<std::string> buft_overrides_draft;
2474
- buft_overrides_draft.push_back(string_format("blk\\.%d\\.ffn_(up|down|gate)_exps", i));
2649
+ buft_overrides_draft.push_back(llm_ffn_exps_block_regex(i));
2475
2650
  params.speculative.tensor_buft_overrides.push_back({buft_overrides_draft.back().c_str(), ggml_backend_cpu_buffer_type()});
2476
2651
  }
2477
2652
  }
@@ -2636,6 +2811,15 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
2636
2811
  params.model.url = value;
2637
2812
  }
2638
2813
  ).set_env("LLAMA_ARG_MODEL_URL"));
2814
+ add_opt(common_arg(
2815
+ { "-dr", "--docker-repo" }, "[<repo>/]<model>[:quant]",
2816
+ "Docker Hub model repository. repo is optional, default to ai/. quant is optional, default to :latest.\n"
2817
+ "example: gemma3\n"
2818
+ "(default: unused)",
2819
+ [](common_params & params, const std::string & value) {
2820
+ params.model.docker_repo = value;
2821
+ }
2822
+ ).set_env("LLAMA_ARG_DOCKER_REPO"));
2639
2823
  add_opt(common_arg(
2640
2824
  {"-hf", "-hfr", "--hf-repo"}, "<user>/<model>[:quant]",
2641
2825
  "Hugging Face model repository; quant is optional, case-insensitive, default to Q4_K_M, or falls back to the first file in the repo if Q4_K_M doesn't exist.\n"