durable_huggingface_hub 0.2.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35) hide show
  1. checksums.yaml +7 -0
  2. data/.editorconfig +29 -0
  3. data/.rubocop.yml +108 -0
  4. data/CHANGELOG.md +127 -0
  5. data/README.md +547 -0
  6. data/Rakefile +106 -0
  7. data/devenv.lock +171 -0
  8. data/devenv.nix +15 -0
  9. data/devenv.yaml +8 -0
  10. data/huggingface_hub.gemspec +63 -0
  11. data/lib/durable_huggingface_hub/authentication.rb +245 -0
  12. data/lib/durable_huggingface_hub/cache.rb +508 -0
  13. data/lib/durable_huggingface_hub/configuration.rb +191 -0
  14. data/lib/durable_huggingface_hub/constants.rb +145 -0
  15. data/lib/durable_huggingface_hub/errors.rb +412 -0
  16. data/lib/durable_huggingface_hub/file_download.rb +831 -0
  17. data/lib/durable_huggingface_hub/hf_api.rb +1278 -0
  18. data/lib/durable_huggingface_hub/repo_card.rb +430 -0
  19. data/lib/durable_huggingface_hub/types/cache_info.rb +298 -0
  20. data/lib/durable_huggingface_hub/types/commit_info.rb +149 -0
  21. data/lib/durable_huggingface_hub/types/dataset_info.rb +158 -0
  22. data/lib/durable_huggingface_hub/types/model_info.rb +154 -0
  23. data/lib/durable_huggingface_hub/types/space_info.rb +158 -0
  24. data/lib/durable_huggingface_hub/types/user.rb +179 -0
  25. data/lib/durable_huggingface_hub/types.rb +205 -0
  26. data/lib/durable_huggingface_hub/utils/auth.rb +174 -0
  27. data/lib/durable_huggingface_hub/utils/headers.rb +220 -0
  28. data/lib/durable_huggingface_hub/utils/http.rb +329 -0
  29. data/lib/durable_huggingface_hub/utils/paths.rb +230 -0
  30. data/lib/durable_huggingface_hub/utils/progress.rb +217 -0
  31. data/lib/durable_huggingface_hub/utils/retry.rb +165 -0
  32. data/lib/durable_huggingface_hub/utils/validators.rb +236 -0
  33. data/lib/durable_huggingface_hub/version.rb +8 -0
  34. data/lib/huggingface_hub.rb +205 -0
  35. metadata +334 -0
@@ -0,0 +1,831 @@
1
+ # frozen_string_literal: true
2
+
3
+ require "pathname"
4
+ require "digest"
5
+ require "fileutils"
6
+ require "json"
7
+ require_relative "utils/progress"
8
+ require_relative "hf_api"
9
+ require_relative "errors"
10
+
11
+ module DurableHuggingfaceHub
12
+ # File download functionality with caching and ETag support.
13
+ #
14
+ # This module provides utilities for downloading files from the HuggingFace Hub
15
+ # with intelligent caching, resume support, and validation using ETags.
16
+ #
17
+ # @example Download a model file
18
+ # path = DurableHuggingfaceHub::FileDownload.hf_hub_download(
19
+ # repo_id: "bert-base-uncased",
20
+ # filename: "config.json"
21
+ # )
22
+ # config = JSON.parse(File.read(path))
23
+ #
24
+ # @example Download with custom cache directory
25
+ # path = DurableHuggingfaceHub::FileDownload.hf_hub_download(
26
+ # repo_id: "gpt2",
27
+ # filename: "pytorch_model.bin",
28
+ # cache_dir: "/custom/cache"
29
+ # )
30
+ module FileDownload
31
# Cache root used when neither an explicit directory nor an environment
# override is supplied: <home>/.cache/huggingface/hub
DEFAULT_CACHE_DIR = Pathname.new(Dir.home) / ".cache" / "huggingface" / "hub"

# Suffix appended to a blob path to name its metadata file
METADATA_FILENAME = ".metadata.json"

# Suffix for lock files used by atomic operations
LOCK_SUFFIX = ".lock"
39
+
40
# Fetches a single repository file, serving it from the local cache when possible.
#
# Inputs are validated first, then the repository's storage folder is derived
# from the cache directory. With `local_files_only` the method never touches
# the network: it returns the cached path or raises. Otherwise the work is
# delegated to {download_file} with the resolved authentication token.
#
# @param repo_id [String] Repository ID (e.g., "bert-base-uncased")
# @param filename [String] Path to file in repository
# @param repo_type [String] Type of repository ("model", "dataset", or "space")
# @param revision [String, nil] Git revision (branch, tag, or commit SHA); "main" when nil
# @param cache_dir [String, Pathname, nil] Custom cache directory
# @param force_download [Boolean] Force re-download even if cached
# @param token [String, nil] HuggingFace API token
# @param local_files_only [Boolean] Only use cached files, don't download
# @param headers [Hash, nil] Additional HTTP headers
# @param progress [Proc, nil] Progress callback (receives current, total, percentage)
#
# @return [Pathname] Path to the downloaded (or cached) file
#
# @raise [LocalEntryNotFoundError] If local_files_only=true and file not cached
# @raise [ValidationError] If parameters are invalid
#
# @example Basic download
#   path = FileDownload.hf_hub_download(
#     repo_id: "bert-base-uncased",
#     filename: "config.json"
#   )
#
# @example Force re-download of a specific revision
#   path = FileDownload.hf_hub_download(
#     repo_id: "gpt2",
#     filename: "pytorch_model.bin",
#     revision: "main",
#     force_download: true
#   )
def self.hf_hub_download(
  repo_id:,
  filename:,
  repo_type: "model",
  revision: nil,
  cache_dir: nil,
  force_download: false,
  token: nil,
  local_files_only: false,
  headers: nil,
  progress: nil
)
  repo_id = Utils::Validators.validate_repo_id(repo_id)
  filename = Utils::Validators.validate_filename(filename)
  repo_type = Utils::Validators.validate_repo_type(repo_type)
  # A missing (or rejected-to-nil) revision falls back to the default branch.
  revision = (revision && Utils::Validators.validate_revision(revision)) || "main"

  storage_folder = get_storage_folder(
    repo_id,
    repo_type: repo_type,
    cache_dir: resolve_cache_dir(cache_dir)
  )

  unless local_files_only
    return download_file(
      repo_id: repo_id,
      filename: filename,
      repo_type: repo_type,
      revision: revision,
      storage_folder: storage_folder,
      force_download: force_download,
      token: Utils::Auth.get_token(token: token),
      headers: headers,
      progress: progress
    )
  end

  # Offline mode: a cache hit is the only acceptable outcome.
  cached_file = find_cached_file(storage_folder, filename, revision)
  return cached_file if cached_file

  raise LocalEntryNotFoundError.new(
    "File #{filename} not found in local cache for #{repo_id}@#{revision}. " \
    "Cannot download because local_files_only=true"
  )
end
137
+
138
# Downloads an entire repository snapshot from the HuggingFace Hub with caching.
#
# Repository info is requested from the Hub (unless `local_files_only`), the
# file list is filtered by the allow/ignore patterns, and each file is fetched
# through `hf_hub_download` (in parallel when `max_workers > 1`).
#
# Offline fallback: when repository info cannot be obtained, the method tries
# to resolve the revision to a commit hash (either the revision already is a
# 40-character hex SHA, or `refs/<revision>` exists in the cache) and returns
# the matching cached snapshot folder, or an existing non-empty `local_dir`.
#
# @param repo_id [String] Repository ID (e.g., "bert-base-uncased")
# @param repo_type [String, Symbol] Type of repository ("model", "dataset", or "space")
# @param revision [String, nil] Git revision (branch, tag, or commit SHA). Defaults to "main"
# @param cache_dir [String, Pathname, nil] Custom cache directory
# @param local_dir [String, Pathname, nil] Custom directory to copy the snapshot to.
#   If nil, the snapshot will remain in the cache.
# @param force_download [Boolean] Force re-download even if cached
# @param token [String, nil] HuggingFace API token
# @param local_files_only [Boolean] Only use cached files, don't download
# @param allow_patterns [Array<String>, String, nil] Glob patterns to include
# @param ignore_patterns [Array<String>, String, nil] Glob patterns to exclude
# @param max_workers [Integer] Number of concurrent downloads (default: 8)
# @param progress [Proc, nil] Progress callback (receives current, total, percentage)
#
# @return [Pathname] Path to the downloaded (or cached) snapshot directory
#
# @raise [RepositoryNotFoundError] If the Hub reported the repository as missing
# @raise [RevisionNotFoundError] If the Hub reported the revision as missing
# @raise [LocalEntryNotFoundError] If no usable cached snapshot exists and the
#   Hub could not be queried (or local_files_only=true)
# @raise [DurableHuggingfaceHubError] If the repository info carries no commit SHA
#
# @example Download only specific file patterns
#   filtered_dir = FileDownload.snapshot_download(
#     repo_id: "bert-base-uncased",
#     allow_patterns: ["*.json", "*.txt"],
#     ignore_patterns: ["*.bin"]
#   )
def self.snapshot_download(
  repo_id:,
  repo_type: "model",
  revision: nil,
  cache_dir: nil,
  local_dir: nil,
  force_download: false,
  token: nil,
  local_files_only: false,
  allow_patterns: nil,
  ignore_patterns: nil,
  max_workers: 8,
  progress: nil
)
  repo_id = Utils::Validators.validate_repo_id(repo_id)
  repo_type = Utils::Validators.validate_repo_type(repo_type)
  revision = Utils::Validators.validate_revision(revision) if revision
  revision ||= "main"

  cache_dir = resolve_cache_dir(cache_dir)
  storage_folder = get_storage_folder(repo_id, repo_type: repo_type, cache_dir: cache_dir)
  token = Utils::Auth.get_token(token: token)

  api = nil
  repo_info = nil
  api_call_error = nil

  unless local_files_only
    begin
      api = HfApi.new(token: token)
      repo_info = api.repo_info(repo_id, repo_type: repo_type, revision: revision)
    rescue StandardError => e
      # Keep the error but carry on: cached files may still satisfy the request.
      api_call_error = e
    end
  end

  if repo_info.nil?
    # \A/\z anchor the whole string; ^/$ would accept a SHA on any line of a
    # multi-line revision.
    commit_hash =
      if revision.match?(/\A[0-9a-f]{40}\z/)
        revision
      else
        ref_file = storage_folder.join("refs", revision)
        # Only a regular file can hold a commit hash; a directory here would
        # make #read raise.
        ref_file.read.strip if ref_file.file?
      end

    if commit_hash && local_dir.nil?
      snapshot_folder = storage_folder.join("snapshots", commit_hash)
      return snapshot_folder if snapshot_folder.exist? && snapshot_folder.directory?
    end

    if local_dir
      local_dir_path = Utils::Paths.expand_path(local_dir)
      if local_dir_path.exist? && local_dir_path.directory? && !local_dir_path.children.empty?
        warn "Returning existing local_dir #{local_dir_path} as remote repo cannot be accessed"
        return local_dir_path
      end
    end

    # Nothing usable in the cache: report the most specific failure.
    if local_files_only
      raise LocalEntryNotFoundError.new(
        "Cannot find an appropriate cached snapshot folder for #{repo_id}@#{revision}. " \
        "To enable downloads, set local_files_only=false"
      )
    elsif api_call_error.is_a?(RepositoryNotFoundError) || api_call_error.is_a?(RevisionNotFoundError)
      raise api_call_error
    else
      raise LocalEntryNotFoundError.new(
        "An error occurred while trying to locate files on the Hub, and we cannot find " \
        "the appropriate snapshot folder for #{repo_id}@#{revision} in the local cache. " \
        "Please check your internet connection and try again. Error: #{api_call_error&.message}"
      )
    end
  end

  commit_hash = repo_info.sha
  raise DurableHuggingfaceHubError, "Repo info must have a commit SHA" unless commit_hash

  snapshot_folder = storage_folder.join("snapshots", commit_hash)

  # Remember which commit a branch/tag name currently points to.
  update_refs(storage_folder, revision, commit_hash) if revision != commit_hash

  all_files =
    if repo_info.respond_to?(:siblings) && repo_info.siblings
      # Sibling entries may be keyed by symbol or by string.
      repo_info.siblings.map { |sibling| sibling[:rfilename] || sibling["rfilename"] }.compact
    else
      api.list_repo_files(repo_id: repo_id, repo_type: repo_type, revision: commit_hash)
    end

  filtered_files = Utils::Paths.filter_repo_objects(
    all_files,
    allow_patterns: allow_patterns,
    ignore_patterns: ignore_patterns
  )

  if max_workers > 1
    download_files_parallel(
      repo_id: repo_id,
      files: filtered_files,
      repo_type: repo_type,
      revision: commit_hash,
      cache_dir: cache_dir,
      force_download: force_download,
      token: token,
      max_workers: max_workers,
      progress: progress
    )
  else
    filtered_files.each do |filename|
      hf_hub_download(
        repo_id: repo_id,
        filename: filename,
        repo_type: repo_type,
        revision: commit_hash,
        cache_dir: cache_dir,
        force_download: force_download,
        token: token,
        local_files_only: false,
        progress: progress
      )
    end
  end

  if local_dir
    local_dir_path = Utils::Paths.expand_path(local_dir)
    # Create the target up front: when no file was downloaded the snapshot
    # folder does not exist, nothing is copied, and #realpath would raise
    # Errno::ENOENT on a missing directory.
    FileUtils.mkdir_p(local_dir_path)
    copy_snapshot_to_local_dir(snapshot_folder, local_dir_path)
    return local_dir_path.realpath
  end

  snapshot_folder
end
342
+
343
# Computes the per-repository folder inside the cache directory.
#
# The folder is named "<repo_type>s--<namespace>--<name>" for a two-part
# repository ID and "<repo_type>s--<repo_id>" otherwise.
#
# @param repo_id [String] Repository ID
# @param repo_type [String] Type of repository
# @param cache_dir [String, Pathname, nil] Custom cache directory
# @return [Pathname] Storage folder path
def self.get_storage_folder(repo_id, repo_type: "model", cache_dir: nil)
  cache_root = resolve_cache_dir(cache_dir)

  segments = repo_id.split("/")
  # Only the exact "namespace/name" shape is flattened; anything else is used verbatim.
  flattened_id = segments.length == 2 ? segments.join("--") : repo_id

  cache_root.join("#{repo_type}s--#{flattened_id}")
end
363
+
364
# Resolves the cache directory to use.
#
# Lookup order: the explicit argument, then `$HF_HOME/hub`, then
# `$HUGGINGFACE_HUB_CACHE`, then {DEFAULT_CACHE_DIR}. Environment variables
# that are set but empty are treated as unset; otherwise an empty `HF_HOME`
# would silently resolve to the relative path "hub".
#
# NOTE(review): HF_HOME currently wins over HUGGINGFACE_HUB_CACHE. The more
# specific variable usually takes precedence in HuggingFace tooling — confirm
# whether this order is intentional before changing it.
#
# @param cache_dir [String, Pathname, nil] Custom cache directory
# @return [Pathname] Resolved cache directory
def self.resolve_cache_dir(cache_dir)
  return Utils::Paths.expand_path(cache_dir) if cache_dir

  hf_home = ENV["HF_HOME"].to_s
  return Pathname.new(hf_home).join("hub") unless hf_home.empty?

  hub_cache = ENV["HUGGINGFACE_HUB_CACHE"].to_s
  return Pathname.new(hub_cache) unless hub_cache.empty?

  DEFAULT_CACHE_DIR
end
379
+
380
# Looks up a file in the local snapshot cache without touching the network.
#
# Two locations are tried in order: `snapshots/<revision>/<filename>` (the
# revision used directly as a folder name), then the snapshot of the commit
# hash stored in `refs/<revision>`.
#
# @param storage_folder [Pathname] Repository storage folder
# @param filename [String] File path in repository
# @param revision [String] Git revision
# @return [Pathname, nil] Path to cached file or nil if not found
def self.find_cached_file(storage_folder, filename, revision)
  snapshots_root = storage_folder.join("snapshots")
  return nil unless snapshots_root.exist?

  # Returns the file inside +folder+ when both the folder and the file exist.
  lookup = lambda do |folder|
    next nil unless folder.exist?

    candidate = folder.join(filename)
    candidate.exist? ? candidate : nil
  end

  direct_hit = lookup.call(snapshots_root.join(revision))
  return direct_hit if direct_hit

  refs_root = storage_folder.join("refs")
  return nil unless refs_root.exist?

  pointer = refs_root.join(revision)
  return nil unless pointer.exist?

  # The ref file holds the commit hash the branch/tag points to.
  lookup.call(snapshots_root.join(pointer.read.strip))
end
414
+
415
# Downloads a file and stores it in the cache.
#
# The file's metadata is requested first; the ETag names the blob under
# `blobs/` and the commit hash (or the revision when the Hub sends none) names
# the snapshot folder. Unless `force_download` is set, an existing blob or
# snapshot entry is reused and only the snapshot link and refs are refreshed.
#
# @param repo_id [String] Repository ID
# @param filename [String] File path in repository
# @param repo_type [String] Type of repository
# @param revision [String] Git revision
# @param storage_folder [Pathname] Repository storage folder
# @param force_download [Boolean] Force re-download
# @param token [String, nil] HuggingFace API token
# @param headers [Hash, nil] Additional HTTP headers
# @param progress [Proc, nil] Progress callback
# @return [Pathname] Path to downloaded file
# @raise [DurableHuggingfaceHubError] If the Hub response carries no ETag
def self.download_file(
  repo_id:,
  filename:,
  repo_type:,
  revision:,
  storage_folder:,
  force_download:,
  token:,
  headers:,
  progress:
)
  client = Utils::HttpClient.new(token: token, headers: headers)

  # NOTE(review): this yields "/models/<repo>/resolve/..." for model repos,
  # whereas the hf_hub_url documentation example shows model URLs without a
  # type prefix — confirm against the Hub and the HTTP client's base URL.
  url_path = "/#{repo_type}s/#{repo_id}/resolve/#{revision}/#{filename}"

  metadata = get_file_metadata(client, url_path)

  # The blob is stored under its ETag; without one there is no cache key.
  # Fail with a descriptive error instead of a TypeError from Pathname#join.
  etag = metadata[:etag]
  unless etag
    raise DurableHuggingfaceHubError,
          "Cannot cache #{filename} from #{repo_id}@#{revision}: the Hub response did not include an ETag"
  end

  commit_hash = metadata[:commit_hash] || revision
  blob_path = storage_folder.join("blobs", etag)
  snapshot_path = storage_folder.join("snapshots", commit_hash, filename)

  unless force_download
    if blob_path.exist? && verify_blob(blob_path, etag)
      # Blob already present: make sure the snapshot entry points at it.
      ensure_snapshot_link(blob_path, snapshot_path)
      update_refs(storage_folder, revision, commit_hash)
      return snapshot_path
    elsif snapshot_path.exist?
      # Snapshot entry exists without a matching blob; it is assumed valid.
      update_refs(storage_folder, revision, commit_hash)
      return snapshot_path
    end
  end

  download_to_blob(client, url_path, blob_path, metadata, progress)
  ensure_snapshot_link(blob_path, snapshot_path)
  update_refs(storage_folder, revision, commit_hash)

  snapshot_path
end
477
+
478
# Gets metadata about a file from the Hub via a HEAD request.
#
# The size is read from `x-linked-size` and falls back to `content-length`,
# so responses that carry no linked-size header still report a size.
#
# @param client [Utils::HttpClient] HTTP client (must respond to #head and
#   return an object exposing #headers)
# @param url_path [String] URL path to file
# @return [Hash] Metadata with :etag (String, nil), :size (Integer, nil) and
#   :commit_hash (String, nil)
def self.get_file_metadata(client, url_path)
  headers = client.head(url_path).headers
  raw_size = headers["x-linked-size"] || headers["content-length"]

  {
    etag: extract_etag(headers["etag"] || headers["x-linked-etag"]),
    size: raw_size&.to_i,
    commit_hash: headers["x-repo-commit"]
  }
end
494
+
495
# Normalizes a raw ETag header value.
#
# Strips the weak-validator prefix (`W/`) and the surrounding double quotes.
#
# @param etag [String, nil] Raw ETag header value
# @return [String, nil] Cleaned ETag, or nil when absent or empty after cleaning
def self.extract_etag(etag)
  return nil unless etag

  cleaned = etag.gsub(%r{^W/}, "")
  cleaned = cleaned.gsub(/^"/, "")
  cleaned = cleaned.gsub(/"$/, "")

  cleaned unless cleaned.empty?
end
506
+
507
# Checks that a cached blob corresponds to the expected ETag.
#
# Verification is name-based only: the blob must exist and its file name must
# equal the ETag. File content is not hashed.
#
# @param blob_path [Pathname] Path to blob file
# @param etag [String] Expected ETag
# @return [Boolean] True if blob is valid
def self.verify_blob(blob_path, etag)
  blob_path.exist? && blob_path.basename.to_s == etag
end
523
+
524
# Downloads a file to blob storage.
#
# The body is streamed into a temporary file next to the blob and moved into
# place once complete, so a partially written blob is never visible.
#
# @param client [Utils::HttpClient] HTTP client (its #request must yield a
#   request whose `options.on_data` receives the streamed chunks)
# @param url_path [String] URL path to file
# @param blob_path [Pathname] Destination blob path
# @param metadata [Hash] File metadata (:size is used for progress and for
#   accepting zero-byte files)
# @param progress [Proc, nil] Progress callback
# @raise [DurableHuggingfaceHubError] If nothing was downloaded and the file is
#   not known to be empty
def self.download_to_blob(client, url_path, blob_path, metadata, progress)
  blob_path.dirname.mkpath

  # Include the thread id: files with identical content share a blob name, and
  # the parallel downloader runs several threads in one process.
  temp_path = Pathname.new("#{blob_path}.tmp.#{Process.pid}.#{Thread.current.object_id}")

  progress_tracker =
    if progress
      Utils::Progress.new(total: metadata[:size], callback: progress)
    else
      Utils::NullProgress.new
    end

  begin
    # Open once in truncating mode: leftovers from an earlier aborted attempt
    # are discarded instead of being appended to, and the file is not reopened
    # for every chunk.
    File.open(temp_path, "wb") do |file|
      client.request(:get, url_path) do |req|
        req.options.on_data = proc do |chunk, _overall_received_bytes, _env|
          file.write(chunk)
          progress_tracker.update(chunk.bytesize)
        end
      end
    end

    # An empty result is only acceptable when the Hub reported a size of zero
    # (legitimately empty files exist in repositories).
    downloaded_bytes = temp_path.exist? ? temp_path.size : nil
    if downloaded_bytes.nil? || (downloaded_bytes.zero? && metadata[:size] != 0)
      raise DurableHuggingfaceHubError, "Download failed: file is empty or missing"
    end

    FileUtils.mv(temp_path, blob_path)
    progress_tracker.finish
    write_blob_metadata(blob_path, metadata)
  ensure
    # Remove the temp file on any failure before the move.
    temp_path.unlink if temp_path.exist?
  end
end
572
+
573
# Persists a blob's metadata as pretty-printed JSON.
#
# The metadata file lives next to the blob and is named after it with
# METADATA_FILENAME appended.
#
# @param blob_path [Pathname] Path to blob file
# @param metadata [Hash] Metadata to write
def self.write_blob_metadata(blob_path, metadata)
  sidecar_path = blob_path.to_s + METADATA_FILENAME
  File.write(sidecar_path, JSON.pretty_generate(metadata))
end
581
+
582
# Points a snapshot entry at its blob.
#
# Any existing file or (possibly dangling) symlink at the snapshot path is
# replaced by a relative symlink to the blob. On platforms where symlinks are
# not implemented the blob is copied instead.
#
# @param blob_path [Pathname] Source blob path
# @param snapshot_path [Pathname] Destination snapshot path
def self.ensure_snapshot_link(blob_path, snapshot_path)
  link_dir = snapshot_path.dirname
  link_dir.mkpath

  # symlink? catches dangling links, for which exist? is false.
  occupied = snapshot_path.exist? || snapshot_path.symlink?
  snapshot_path.delete if occupied

  link_target = blob_path.relative_path_from(link_dir)
  File.symlink(link_target.to_s, snapshot_path.to_s)
rescue NotImplementedError
  FileUtils.cp(blob_path, snapshot_path)
end
600
+
601
# Records which commit a named revision (branch/tag/ref) points to.
#
# The commit hash is written to `refs/<revision>` inside the storage folder.
# Revisions may contain slashes (e.g. "refs/pr/1"), so the ref file's own
# parent directory is created, not just `refs/`.
#
# @param storage_folder [Pathname] Repository storage folder
# @param revision [String] Revision name (branch/tag)
# @param commit_hash [String] Commit hash
def self.update_refs(storage_folder, revision, commit_hash)
  # A revision that already is the commit hash needs no pointer file.
  return if revision == commit_hash

  ref_file = storage_folder.join("refs", revision)
  ref_file.dirname.mkpath
  ref_file.write(commit_hash)
end
615
+
616
# Narrows a file list with include and exclude glob patterns.
#
# Patterns are matched with File.fnmatch and File::FNM_PATHNAME, so a
# wildcard does not cross a "/" boundary. Allow patterns are applied first,
# then ignore patterns; a nil pattern set leaves the list untouched.
#
# @param files [Array<String>] List of file paths
# @param allow_patterns [Array<String>, String, nil] Glob patterns to include
# @param ignore_patterns [Array<String>, String, nil] Glob patterns to exclude
# @return [Array<String>] Filtered list of files
def self.filter_repo_files(files, allow_patterns: nil, ignore_patterns: nil)
  matches_any = lambda do |globs, path|
    globs.any? { |glob| File.fnmatch(glob, path, File::FNM_PATHNAME) }
  end

  kept = files

  if allow_patterns
    allowed = Array(allow_patterns)
    kept = kept.select { |path| matches_any.call(allowed, path) }
  end

  if ignore_patterns
    ignored = Array(ignore_patterns)
    kept = kept.reject { |path| matches_any.call(ignored, path) }
  end

  kept
end
643
+
644
# Downloads multiple files concurrently with a pool of worker threads.
#
# Files are handed out through a shared queue; each worker keeps pulling
# until the queue is empty. A failed download is reported with `warn` and does
# not stop the remaining files. The progress callback is invoked once per
# successfully downloaded file with (completed, total, percentage).
#
# @param repo_id [String] Repository ID
# @param files [Array<String>] List of files to download
# @param repo_type [String] Repository type
# @param revision [String] Git revision
# @param cache_dir [Pathname] Cache directory
# @param force_download [Boolean] Force re-download
# @param token [String, nil] Authentication token
# @param max_workers [Integer] Number of concurrent threads
# @param progress [Proc, nil] Progress callback
def self.download_files_parallel(
  repo_id:,
  files:,
  repo_type:,
  revision:,
  cache_dir:,
  force_download:,
  token:,
  max_workers:,
  progress:
)
  require "thread"

  pending = Queue.new
  files.each { |name| pending << name }

  finished = 0
  file_count = files.length
  lock = Mutex.new

  # Unique marker returned once the queue has been drained.
  drained = Object.new
  take_next = lambda do
    begin
      pending.pop(true)
    rescue ThreadError
      drained
    end
  end

  worker_count = [max_workers, file_count].min
  workers = Array.new(worker_count) do
    Thread.new do
      until (file = take_next.call).equal?(drained)
        begin
          hf_hub_download(
            repo_id: repo_id,
            filename: file,
            repo_type: repo_type,
            revision: revision,
            cache_dir: cache_dir,
            force_download: force_download,
            token: token,
            local_files_only: false,
            progress: nil # per-file progress is not reported in parallel mode
          )

          # The counter is shared between workers, so update it under the lock.
          lock.synchronize do
            finished += 1
            progress.call(finished, file_count, (finished.to_f / file_count * 100).round(2)) if progress
          end
        rescue => e
          warn "Failed to download #{file}: #{e.message}"
        end
      end
    end
  end

  workers.each(&:join)
end
715
+
716
# Copies a snapshot directory tree into a local directory as regular files.
#
# Symlinks are dereferenced at every depth, so the copy does not depend on the
# cache's blob store. (A plain recursive copy would reproduce the relative
# symlinks of nested entries, which dangle outside the cache.) Directories are
# merged into an existing destination, so copying the same snapshot twice does
# not nest `dir/dir`. Symlinks that do not resolve to a regular file are
# skipped. Nothing is created when the snapshot folder does not exist.
#
# @param snapshot_folder [Pathname] Source snapshot folder
# @param local_dir_path [Pathname] Destination local directory
def self.copy_snapshot_to_local_dir(snapshot_folder, local_dir_path)
  return unless snapshot_folder.exist?

  FileUtils.mkdir_p(local_dir_path)

  snapshot_folder.children.each do |entry|
    dest = local_dir_path.join(entry.basename)

    if entry.symlink?
      # Resolve relative link targets against the directory holding the link.
      target = entry.readlink
      target = entry.dirname.join(target) unless target.absolute?
      FileUtils.cp(target, dest) if target.file?
    elsif entry.directory?
      copy_snapshot_to_local_dir(entry, dest)
    elsif entry.file?
      FileUtils.cp(entry, dest)
    end
  end
end
744
+
745
# Returns the cached path of a file, or nil when it is not cached.
#
# This is a non-raising wrapper around `hf_hub_download` called with
# `local_files_only: true`: a LocalEntryNotFoundError is turned into nil,
# every other error propagates.
#
# @param repo_id [String] Repository ID
# @param filename [String] File path in repository
# @param repo_type [String] Type of repository
# @param revision [String] Git revision
# @param cache_dir [String, Pathname, nil] Custom cache directory
# @return [Pathname, nil] Path to cached file, or nil if not found
#
# @example Check if file is cached
#   path = FileDownload.try_to_load_from_cache(
#     repo_id: "bert-base-uncased",
#     filename: "config.json",
#     revision: "main"
#   )
#   puts(path ? "File is cached at: #{path}" : "File not in cache")
def self.try_to_load_from_cache(
  repo_id:,
  filename:,
  repo_type: "model",
  revision: "main",
  cache_dir: nil
)
  lookup = {
    repo_id: repo_id,
    filename: filename,
    repo_type: repo_type,
    revision: revision,
    cache_dir: cache_dir,
    local_files_only: true
  }
  hf_hub_download(**lookup)
rescue LocalEntryNotFoundError
  nil
end
789
+
790
# Generate the HuggingFace Hub URL for a file in a repository.
#
# Model repositories live at the endpoint root; other repository types are
# prefixed with their pluralized type ("datasets/", "spaces/").
#
# @param repo_id [String] Repository ID
# @param filename [String] File path in repository
# @param repo_type [String] Type of repository
# @param revision [String] Git revision (defaults to "main")
# @param endpoint [String, nil] Custom endpoint URL; the configured endpoint
#   is used when nil. A trailing "/" is ignored.
# @return [String] Full URL to the file on HuggingFace Hub
#
# @example Generate URL for a model file
#   url = FileDownload.hf_hub_url(
#     repo_id: "bert-base-uncased",
#     filename: "config.json"
#   )
#   # => "https://huggingface.co/bert-base-uncased/resolve/main/config.json"
#
# @example Generate URL for a dataset file
#   url = FileDownload.hf_hub_url(
#     repo_id: "squad",
#     filename: "train.json",
#     repo_type: "dataset",
#     revision: "v1.0"
#   )
#   # => "https://huggingface.co/datasets/squad/resolve/v1.0/train.json"
def self.hf_hub_url(
  repo_id:,
  filename:,
  repo_type: "model",
  revision: "main",
  endpoint: nil
)
  repo_id = Utils::Validators.validate_repo_id(repo_id)
  filename = Utils::Validators.validate_filename(filename)
  repo_type = Utils::Validators.validate_repo_type(repo_type)
  revision = Utils::Validators.validate_revision(revision)

  endpoint ||= DurableHuggingfaceHub.configuration.endpoint
  endpoint = endpoint.chomp("/")

  # Models have no path prefix on the Hub; a "models/" segment would not
  # match the documented URL shape.
  type_prefix = repo_type.to_s == "model" ? "" : "#{repo_type}s/"

  "#{endpoint}/#{type_prefix}#{repo_id}/resolve/#{revision}/#{filename}"
end
830
+ end
831
+ end