durable_huggingface_hub 0.2.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/.editorconfig +29 -0
- data/.rubocop.yml +108 -0
- data/CHANGELOG.md +127 -0
- data/README.md +547 -0
- data/Rakefile +106 -0
- data/devenv.lock +171 -0
- data/devenv.nix +15 -0
- data/devenv.yaml +8 -0
- data/huggingface_hub.gemspec +63 -0
- data/lib/durable_huggingface_hub/authentication.rb +245 -0
- data/lib/durable_huggingface_hub/cache.rb +508 -0
- data/lib/durable_huggingface_hub/configuration.rb +191 -0
- data/lib/durable_huggingface_hub/constants.rb +145 -0
- data/lib/durable_huggingface_hub/errors.rb +412 -0
- data/lib/durable_huggingface_hub/file_download.rb +831 -0
- data/lib/durable_huggingface_hub/hf_api.rb +1278 -0
- data/lib/durable_huggingface_hub/repo_card.rb +430 -0
- data/lib/durable_huggingface_hub/types/cache_info.rb +298 -0
- data/lib/durable_huggingface_hub/types/commit_info.rb +149 -0
- data/lib/durable_huggingface_hub/types/dataset_info.rb +158 -0
- data/lib/durable_huggingface_hub/types/model_info.rb +154 -0
- data/lib/durable_huggingface_hub/types/space_info.rb +158 -0
- data/lib/durable_huggingface_hub/types/user.rb +179 -0
- data/lib/durable_huggingface_hub/types.rb +205 -0
- data/lib/durable_huggingface_hub/utils/auth.rb +174 -0
- data/lib/durable_huggingface_hub/utils/headers.rb +220 -0
- data/lib/durable_huggingface_hub/utils/http.rb +329 -0
- data/lib/durable_huggingface_hub/utils/paths.rb +230 -0
- data/lib/durable_huggingface_hub/utils/progress.rb +217 -0
- data/lib/durable_huggingface_hub/utils/retry.rb +165 -0
- data/lib/durable_huggingface_hub/utils/validators.rb +236 -0
- data/lib/durable_huggingface_hub/version.rb +8 -0
- data/lib/huggingface_hub.rb +205 -0
- metadata +334 -0

data/lib/durable_huggingface_hub/file_download.rb
@@ -0,0 +1,831 @@
# frozen_string_literal: true

require "pathname"
require "digest"
require "fileutils"
require "json"
require_relative "utils/progress"
require_relative "hf_api"
require_relative "errors"

module DurableHuggingfaceHub
  # File download functionality with caching and ETag support.
  #
  # This module provides utilities for downloading files from the HuggingFace Hub
  # with intelligent caching, resume support, and validation using ETags.
  #
  # @example Download a model file
  #   path = DurableHuggingfaceHub::FileDownload.hf_hub_download(
  #     repo_id: "bert-base-uncased",
  #     filename: "config.json"
  #   )
  #   config = JSON.parse(File.read(path))
  #
  # @example Download with custom cache directory
  #   path = DurableHuggingfaceHub::FileDownload.hf_hub_download(
  #     repo_id: "gpt2",
  #     filename: "pytorch_model.bin",
  #     cache_dir: "/custom/cache"
  #   )
  module FileDownload
    # Default cache directory location
    DEFAULT_CACHE_DIR = Pathname.new(Dir.home).join(".cache", "huggingface", "hub")

    # Metadata file name for cache entries
    METADATA_FILENAME = ".metadata.json"

    # Lock file suffix for atomic operations
    LOCK_SUFFIX = ".lock"

    # Downloads a file from the HuggingFace Hub with caching.
    #
    # This method downloads a file from a HuggingFace Hub repository and caches
    # it locally. It uses ETags to avoid re-downloading unchanged files and
    # supports atomic operations to prevent cache corruption.
    #
    # @param repo_id [String] Repository ID (e.g., "bert-base-uncased")
    # @param filename [String] Path to file in repository
    # @param repo_type [String] Type of repository ("model", "dataset", or "space")
    # @param revision [String, nil] Git revision (branch, tag, or commit SHA)
    # @param cache_dir [String, Pathname, nil] Custom cache directory
    # @param force_download [Boolean] Force re-download even if cached
    # @param token [String, nil] HuggingFace API token
    # @param local_files_only [Boolean] Only use cached files, don't download
    # @param headers [Hash, nil] Additional HTTP headers
    # @param progress [Proc, nil] Progress callback (receives current, total, percentage)
    #
    # @return [Pathname] Path to the downloaded (or cached) file
    #
    # @raise [RepositoryNotFoundError] If repository doesn't exist
    # @raise [EntryNotFoundError] If file doesn't exist in repository
    # @raise [LocalEntryNotFoundError] If local_files_only=true and file not cached
    # @raise [ValidationError] If parameters are invalid
    #
    # @example Basic download
    #   path = FileDownload.hf_hub_download(
    #     repo_id: "bert-base-uncased",
    #     filename: "config.json"
    #   )
    #
    # @example Download specific revision
    #   path = FileDownload.hf_hub_download(
    #     repo_id: "gpt2",
    #     filename: "pytorch_model.bin",
    #     revision: "main"
    #   )
    #
    # @example Force re-download
    #   path = FileDownload.hf_hub_download(
    #     repo_id: "bert-base-uncased",
    #     filename: "config.json",
    #     force_download: true
    #   )
    def self.hf_hub_download(
      repo_id:,
      filename:,
      repo_type: "model",
      revision: nil,
      cache_dir: nil,
      force_download: false,
      token: nil,
      local_files_only: false,
      headers: nil,
      progress: nil
    )
      # Validate inputs
      repo_id = Utils::Validators.validate_repo_id(repo_id)
      filename = Utils::Validators.validate_filename(filename)
      repo_type = Utils::Validators.validate_repo_type(repo_type)
      revision = Utils::Validators.validate_revision(revision) if revision

      # Get cache directory
      cache_dir = resolve_cache_dir(cache_dir)

      # Build storage paths
      storage_folder = get_storage_folder(repo_id, repo_type: repo_type, cache_dir: cache_dir)
      revision ||= "main"

      # Check if we can use local files only
      if local_files_only
        cached_path = find_cached_file(storage_folder, filename, revision)
        if cached_path
          return cached_path
        else
          raise LocalEntryNotFoundError.new(
            "File #{filename} not found in local cache for #{repo_id}@#{revision}. " \
            "Cannot download because local_files_only=true"
          )
        end
      end

      # Get token for authentication
      token = Utils::Auth.get_token(token: token)

      # Download or retrieve from cache
      download_file(
        repo_id: repo_id,
        filename: filename,
        repo_type: repo_type,
        revision: revision,
        storage_folder: storage_folder,
        force_download: force_download,
        token: token,
        headers: headers,
        progress: progress
      )
    end
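
    # A minimal sketch of the progress callback wired through above
    # (illustrative, not part of the gem: the three-argument signature
    # follows the doc comment; the terminal rendering is an assumption):
    #
    #   reporter = ->(current, total, percentage) do
    #     print "\rconfig.json: #{percentage}% (#{current}/#{total} bytes)"
    #   end
    #   FileDownload.hf_hub_download(
    #     repo_id: "bert-base-uncased",
    #     filename: "config.json",
    #     progress: reporter
    #   )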

    # Downloads an entire repository snapshot from the HuggingFace Hub with caching.
    #
    # This method downloads all files from a HuggingFace Hub repository for a given
    # revision and stores them in a local cache directory. It leverages `hf_hub_download`
    # for individual file downloads and supports filtering by patterns.
    #
    # The method implements robust offline fallback: if the Hub is unavailable or network
    # is down, it will try to use locally cached files. It properly handles commit hash
    # resolution for branches and tags.
    #
    # @param repo_id [String] Repository ID (e.g., "bert-base-uncased")
    # @param repo_type [String, Symbol] Type of repository ("model", "dataset", or "space")
    # @param revision [String, nil] Git revision (branch, tag, or commit SHA). Defaults to "main"
    # @param cache_dir [String, Pathname, nil] Custom cache directory
    # @param local_dir [String, Pathname, nil] Custom directory to copy the snapshot to.
    #   If nil, the snapshot will remain in the cache.
    # @param force_download [Boolean] Force re-download even if cached
    # @param token [String, nil] HuggingFace API token
    # @param local_files_only [Boolean] Only use cached files, don't download
    # @param allow_patterns [Array<String>, String, nil] Glob patterns to include (e.g., "*.json", ["*.py", "*.md"])
    # @param ignore_patterns [Array<String>, String, nil] Glob patterns to exclude (e.g., "*.bin", ["*.safetensors"])
    # @param max_workers [Integer] Number of concurrent downloads (default: 8)
    # @param progress [Proc, nil] Progress callback (receives current, total, percentage)
    #
    # @return [Pathname] Path to the downloaded (or cached) snapshot directory
    #
    # @raise [RepositoryNotFoundError] If repository doesn't exist
    # @raise [RevisionNotFoundError] If revision doesn't exist
    # @raise [LocalEntryNotFoundError] If local_files_only=true and files not cached
    # @raise [ValidationError] If parameters are invalid
    #
    # @example Download entire model repository
    #   local_dir = FileDownload.snapshot_download(
    #     repo_id: "gpt2",
    #     revision: "main"
    #   )
    #
    # @example Download only specific file patterns
    #   filtered_dir = FileDownload.snapshot_download(
    #     repo_id: "bert-base-uncased",
    #     allow_patterns: ["*.json", "*.txt"],
    #     ignore_patterns: ["*.bin"]
    #   )
    #
    # @example Download with parallel downloads
    #   snapshot = FileDownload.snapshot_download(
    #     repo_id: "bert-base-uncased",
    #     max_workers: 16
    #   )
    def self.snapshot_download(
      repo_id:,
      repo_type: "model",
      revision: nil,
      cache_dir: nil,
      local_dir: nil,
      force_download: false,
      token: nil,
      local_files_only: false,
      allow_patterns: nil,
      ignore_patterns: nil,
      max_workers: 8,
      progress: nil
    )
      # Validate inputs
      repo_id = Utils::Validators.validate_repo_id(repo_id)
      repo_type = Utils::Validators.validate_repo_type(repo_type)
      revision = Utils::Validators.validate_revision(revision) if revision
      revision ||= "main"

      # Get cache directory and storage folder
      cache_dir = resolve_cache_dir(cache_dir)
      storage_folder = get_storage_folder(repo_id, repo_type: repo_type, cache_dir: cache_dir)

      # Get token for authentication
      token = Utils::Auth.get_token(token: token)

      # Try to fetch repository info from Hub
      repo_info = nil
      api_call_error = nil

      unless local_files_only
        begin
          # Initialize HfApi client
          api = HfApi.new(token: token)
          repo_info = api.repo_info(repo_id, repo_type: repo_type, revision: revision)
        rescue StandardError => e
          # Store error but continue - we might be able to use cached files
          api_call_error = e
        end
      end

      # If we couldn't get repo_info, try to use cached files
      if repo_info.nil?
        # Try to resolve commit hash from revision
        commit_hash = nil

        # Check if revision is already a commit hash
        if revision.match?(/^[0-9a-f]{40}$/)
          commit_hash = revision
        else
          # Try to read commit hash from refs
          ref_file = storage_folder.join("refs", revision)
          if ref_file.exist?
            commit_hash = ref_file.read.strip
          end
        end

        # Try to locate snapshot folder for this commit hash
        if commit_hash && local_dir.nil?
          snapshot_folder = storage_folder.join("snapshots", commit_hash)
          if snapshot_folder.exist? && snapshot_folder.directory?
            # Snapshot folder exists => return it
            return snapshot_folder
          end
        end

        # If local_dir is specified and exists, return it
        if local_dir
          local_dir_path = Utils::Paths.expand_path(local_dir)
          if local_dir_path.exist? && local_dir_path.directory? && !local_dir_path.children.empty?
            warn "Returning existing local_dir #{local_dir_path} as remote repo cannot be accessed"
            return local_dir_path
          end
        end

        # Could not find cached files - raise appropriate error
        if local_files_only
          raise LocalEntryNotFoundError.new(
            "Cannot find an appropriate cached snapshot folder for #{repo_id}@#{revision}. " \
            "To enable downloads, set local_files_only=false"
          )
        elsif api_call_error.is_a?(RepositoryNotFoundError) || api_call_error.is_a?(RevisionNotFoundError)
          raise api_call_error
        else
          raise LocalEntryNotFoundError.new(
            "An error occurred while trying to locate files on the Hub, and we cannot find " \
            "the appropriate snapshot folder for #{repo_id}@#{revision} in the local cache. " \
            "Please check your internet connection and try again. Error: #{api_call_error&.message}"
          )
        end
      end

      # At this point, we have repo_info with a valid commit hash
      commit_hash = repo_info.sha
      raise DurableHuggingfaceHubError, "Repo info must have a commit SHA" unless commit_hash

      # Determine snapshot folder
      snapshot_folder = storage_folder.join("snapshots", commit_hash)

      # Store ref if revision is not a commit hash
      if revision != commit_hash
        update_refs(storage_folder, revision, commit_hash)
      end

      # Get list of files from repo_info
      all_files = if repo_info.respond_to?(:siblings) && repo_info.siblings
                    repo_info.siblings.map { |sibling| sibling[:rfilename] || sibling["rfilename"] }.compact
                  else
                    # Fallback to API call if siblings not available
                    api.list_repo_files(repo_id: repo_id, repo_type: repo_type, revision: commit_hash)
                  end

      # Filter files based on allow_patterns and ignore_patterns
      filtered_files = Utils::Paths.filter_repo_objects(all_files, allow_patterns: allow_patterns, ignore_patterns: ignore_patterns)

      # Download files (with parallelization if max_workers > 1)
      if max_workers > 1
        download_files_parallel(
          repo_id: repo_id,
          files: filtered_files,
          repo_type: repo_type,
          revision: commit_hash,
          cache_dir: cache_dir,
          force_download: force_download,
          token: token,
          max_workers: max_workers,
          progress: progress
        )
      else
        # Sequential download
        filtered_files.each do |filename|
          hf_hub_download(
            repo_id: repo_id,
            filename: filename,
            repo_type: repo_type,
            revision: commit_hash,
            cache_dir: cache_dir,
            force_download: force_download,
            token: token,
            local_files_only: false,
            progress: progress
          )
        end
      end

      # If local_dir is specified, copy the snapshot there
      if local_dir
        local_dir_path = Utils::Paths.expand_path(local_dir)
        copy_snapshot_to_local_dir(snapshot_folder, local_dir_path)
        return local_dir_path.realpath
      end

      snapshot_folder
    end
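
    # A minimal sketch of the offline fallback described above (illustrative,
    # not part of the gem; assumes the snapshot was fetched once while online,
    # so refs/ and snapshots/ are already populated):
    #
    #   FileDownload.snapshot_download(repo_id: "gpt2")  # online: fills the cache
    #   FileDownload.snapshot_download(
    #     repo_id: "gpt2",
    #     local_files_only: true                         # offline: resolves refs/main
    #   )                                                # and returns the cached snapshot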

    # Gets the cache directory for a repository.
    #
    # @param repo_id [String] Repository ID
    # @param repo_type [String] Type of repository
    # @param cache_dir [String, Pathname, nil] Custom cache directory
    # @return [Pathname] Storage folder path
    def self.get_storage_folder(repo_id, repo_type: "model", cache_dir: nil)
      cache_dir = resolve_cache_dir(cache_dir)

      # Create a unique folder name based on repo_id and type
      # Format: models--namespace--name or models--name
      repo_id_parts = repo_id.split("/")
      if repo_id_parts.length == 2
        folder_name = "#{repo_type}s--#{repo_id_parts[0]}--#{repo_id_parts[1]}"
      else
        folder_name = "#{repo_type}s--#{repo_id}"
      end

      cache_dir.join(folder_name)
    end

    # Resolves the cache directory to use.
    #
    # @param cache_dir [String, Pathname, nil] Custom cache directory
    # @return [Pathname] Resolved cache directory
    def self.resolve_cache_dir(cache_dir)
      if cache_dir
        Utils::Paths.expand_path(cache_dir)
      elsif ENV["HF_HOME"]
        Pathname.new(ENV["HF_HOME"]).join("hub")
      elsif ENV["HUGGINGFACE_HUB_CACHE"]
        Pathname.new(ENV["HUGGINGFACE_HUB_CACHE"])
      else
        DEFAULT_CACHE_DIR
      end
    end

    # Finds a cached file for a specific revision.
    #
    # @param storage_folder [Pathname] Repository storage folder
    # @param filename [String] File path in repository
    # @param revision [String] Git revision
    # @return [Pathname, nil] Path to cached file or nil if not found
    def self.find_cached_file(storage_folder, filename, revision)
      # Look for snapshot folder for this revision
      snapshots_folder = storage_folder.join("snapshots")
      return nil unless snapshots_folder.exist?

      # Try to find by revision folder
      revision_folder = snapshots_folder.join(revision)
      if revision_folder.exist?
        file_path = revision_folder.join(filename)
        return file_path if file_path.exist?
      end

      # Try to find in refs folder
      refs_folder = storage_folder.join("refs")
      if refs_folder.exist?
        ref_file = refs_folder.join(revision)
        if ref_file.exist?
          commit_hash = ref_file.read.strip
          commit_folder = snapshots_folder.join(commit_hash)
          if commit_folder.exist?
            file_path = commit_folder.join(filename)
            return file_path if file_path.exist?
          end
        end
      end

      nil
    end

    # Downloads a file and stores it in the cache.
    #
    # @param repo_id [String] Repository ID
    # @param filename [String] File path in repository
    # @param repo_type [String] Type of repository
    # @param revision [String] Git revision
    # @param storage_folder [Pathname] Repository storage folder
    # @param force_download [Boolean] Force re-download
    # @param token [String, nil] HuggingFace API token
    # @param headers [Hash, nil] Additional HTTP headers
    # @param progress [Proc, nil] Progress callback
    # @return [Pathname] Path to downloaded file
    def self.download_file(
      repo_id:,
      filename:,
      repo_type:,
      revision:,
      storage_folder:,
      force_download:,
      token:,
      headers:,
      progress:
    )
      # Create HTTP client
      client = Utils::HttpClient.new(token: token, headers: headers)

      # Build URL for file
      url_path = "/#{repo_type}s/#{repo_id}/resolve/#{revision}/#{filename}"

      # Get metadata about the file (including ETag and commit hash)
      metadata = get_file_metadata(client, url_path)

      # Determine final storage location
      commit_hash = metadata[:commit_hash] || revision
      blob_path = storage_folder.join("blobs", metadata[:etag])
      snapshot_path = storage_folder.join("snapshots", commit_hash, filename)

      # Check if we already have this file (by ETag or snapshot file)
      unless force_download
        if blob_path.exist? && verify_blob(blob_path, metadata[:etag])
          # File exists in blob storage, create symlink if needed
          ensure_snapshot_link(blob_path, snapshot_path)
          update_refs(storage_folder, revision, commit_hash)
          return snapshot_path
        elsif snapshot_path.exist?
          # File exists in snapshot, assume it's valid
          update_refs(storage_folder, revision, commit_hash)
          return snapshot_path
        end
      end

      # Download the file to blob storage
      download_to_blob(client, url_path, blob_path, metadata, progress)

      # Create snapshot symlink
      ensure_snapshot_link(blob_path, snapshot_path)

      # Update refs
      update_refs(storage_folder, revision, commit_hash)

      snapshot_path
    end

    # Gets metadata about a file from the Hub.
    #
    # @param client [Utils::HttpClient] HTTP client
    # @param url_path [String] URL path to file
    # @return [Hash] Metadata including :etag, :size, :commit_hash
    def self.get_file_metadata(client, url_path)
      response = client.head(url_path)

      # Extract metadata from headers (response is now a Faraday::Response object)
      headers = response.headers
      {
        etag: extract_etag(headers["etag"] || headers["x-linked-etag"]),
        size: headers["x-linked-size"]&.to_i,
        commit_hash: headers["x-repo-commit"]
      }
    end

    # Extracts clean ETag from header value.
    #
    # @param etag [String, nil] Raw ETag header value
    # @return [String, nil] Cleaned ETag
    def self.extract_etag(etag)
      return nil unless etag

      # Remove quotes and W/ prefix
      etag = etag.gsub(/^W\//, "").gsub(/^"/, "").gsub(/"$/, "")
      etag.empty? ? nil : etag
    end

    # Verifies a blob file matches the expected ETag.
    #
    # @param blob_path [Pathname] Path to blob file
    # @param etag [String] Expected ETag
    # @return [Boolean] True if blob is valid
    def self.verify_blob(blob_path, etag)
      return false unless blob_path.exist?

      # First check filename matches ETag (fast check)
      return false unless blob_path.basename.to_s == etag

      # For more robust verification, we could compute the actual ETag
      # from file content, but for now we trust the filename-based approach
      # used by HuggingFace Hub
      true
    end

    # Downloads a file to blob storage.
    #
    # @param client [Utils::HttpClient] HTTP client
    # @param url_path [String] URL path to file
    # @param blob_path [Pathname] Destination blob path
    # @param metadata [Hash] File metadata
    # @param progress [Proc, nil] Progress callback
    def self.download_to_blob(client, url_path, blob_path, metadata, progress)
      # Ensure blobs directory exists
      blob_path.dirname.mkpath

      # Download to temporary file first (atomic operation)
      temp_path = Pathname.new("#{blob_path}.tmp.#{Process.pid}")

      # Create progress tracker
      progress_tracker = if progress
                           Utils::Progress.new(total: metadata[:size], callback: progress)
                         else
                           Utils::NullProgress.new
                         end

      begin
        # Download file
        response = client.request(:get, url_path) do |req|
          req.options.on_data = proc do |chunk, _overall_received_bytes, _env|
            File.open(temp_path, "ab") { |f| f.write(chunk) }
            progress_tracker.update(chunk.bytesize)
          end
        end

        # Verify download
        unless temp_path.exist? && temp_path.size.positive?
          raise DurableHuggingfaceHubError, "Download failed: file is empty or missing"
        end

        # Move to final location atomically
        FileUtils.mv(temp_path, blob_path)

        # Mark progress as finished
        progress_tracker.finish

        # Write metadata
        write_blob_metadata(blob_path, metadata)
      ensure
        # Clean up temp file if it still exists
        temp_path.unlink if temp_path.exist?
      end
    end

    # Writes metadata for a blob file.
    #
    # @param blob_path [Pathname] Path to blob file
    # @param metadata [Hash] Metadata to write
    def self.write_blob_metadata(blob_path, metadata)
      metadata_path = Pathname.new("#{blob_path}#{METADATA_FILENAME}")
      metadata_path.write(JSON.pretty_generate(metadata))
    end

    # Ensures a symlink exists from snapshot to blob.
    #
    # @param blob_path [Pathname] Source blob path
    # @param snapshot_path [Pathname] Destination snapshot path
    def self.ensure_snapshot_link(blob_path, snapshot_path)
      # Create snapshot directory if needed
      snapshot_path.dirname.mkpath

      # Remove existing file/link if present
      snapshot_path.unlink if snapshot_path.exist? || snapshot_path.symlink?

      # Create relative symlink
      relative_blob_path = blob_path.relative_path_from(snapshot_path.dirname)
      snapshot_path.make_symlink(relative_blob_path)
    rescue NotImplementedError
      # System doesn't support symlinks, copy instead
      FileUtils.cp(blob_path, snapshot_path)
    end

    # Updates refs to point to the latest commit hash.
    #
    # @param storage_folder [Pathname] Repository storage folder
    # @param revision [String] Revision name (branch/tag)
    # @param commit_hash [String] Commit hash
    def self.update_refs(storage_folder, revision, commit_hash)
      return if revision == commit_hash # Don't create ref for commit hashes

      refs_folder = storage_folder.join("refs")
      refs_folder.mkpath

      ref_file = refs_folder.join(revision)
      ref_file.write(commit_hash)
    end

    # Filters files based on glob patterns.
    #
    # @param files [Array<String>] List of file paths
    # @param allow_patterns [Array<String>, String, nil] Glob patterns to include
    # @param ignore_patterns [Array<String>, String, nil] Glob patterns to exclude
    # @return [Array<String>] Filtered list of files
    def self.filter_repo_files(files, allow_patterns: nil, ignore_patterns: nil)
      filtered = files

      # Apply allow_patterns if specified
      if allow_patterns
        patterns = Array(allow_patterns)
        filtered = filtered.select do |filename|
          patterns.any? { |pattern| File.fnmatch(pattern, filename, File::FNM_PATHNAME) }
        end
      end

      # Apply ignore_patterns if specified
      if ignore_patterns
        patterns = Array(ignore_patterns)
        filtered = filtered.reject do |filename|
          patterns.any? { |pattern| File.fnmatch(pattern, filename, File::FNM_PATHNAME) }
        end
      end

      filtered
    end

    # Downloads multiple files in parallel using threads.
    #
    # @param repo_id [String] Repository ID
    # @param files [Array<String>] List of files to download
    # @param repo_type [String] Repository type
    # @param revision [String] Git revision
    # @param cache_dir [Pathname] Cache directory
    # @param force_download [Boolean] Force re-download
    # @param token [String, nil] Authentication token
    # @param max_workers [Integer] Number of concurrent threads
    # @param progress [Proc, nil] Progress callback
    def self.download_files_parallel(
      repo_id:,
      files:,
      repo_type:,
      revision:,
      cache_dir:,
      force_download:,
      token:,
      max_workers:,
      progress:
    )
      require "thread"

      # Create a queue of files to download
      queue = Queue.new
      files.each { |file| queue << file }

      # Track completed downloads
      completed = 0
      total = files.length
      mutex = Mutex.new

      # Create worker threads
      threads = Array.new([max_workers, files.length].min) do
        Thread.new do
          loop do
            file = begin
              queue.pop(true)
            rescue ThreadError
              break # Queue is empty
            end

            begin
              hf_hub_download(
                repo_id: repo_id,
                filename: file,
                repo_type: repo_type,
                revision: revision,
                cache_dir: cache_dir,
                force_download: force_download,
                token: token,
                local_files_only: false,
                progress: nil # Individual file progress not supported in parallel mode
              )

              mutex.synchronize do
                completed += 1
                progress&.call(completed, total, (completed.to_f / total * 100).round(2)) if progress
              end
            rescue => e
              warn "Failed to download #{file}: #{e.message}"
              # Continue with other files
            end
          end
        end
      end

      # Wait for all threads to complete
      threads.each(&:join)
    end

    # Copies snapshot directory to local directory.
    #
    # @param snapshot_folder [Pathname] Source snapshot folder
    # @param local_dir_path [Pathname] Destination local directory
    def self.copy_snapshot_to_local_dir(snapshot_folder, local_dir_path)
      return unless snapshot_folder.exist?

      FileUtils.mkdir_p(local_dir_path)

      # Copy all files and directories from snapshot to local_dir
      snapshot_folder.children.each do |entry|
        dest = local_dir_path.join(entry.basename)

        if entry.symlink?
          # For symlinks, copy the actual file content
          target = entry.readlink
          target = entry.dirname.join(target) unless target.absolute?

          if target.file?
            FileUtils.cp(target, dest)
          end
        elsif entry.directory?
          FileUtils.cp_r(entry, dest)
        elsif entry.file?
          FileUtils.cp(entry, dest)
        end
      end
    end

    # Try to load a file from cache without downloading.
    #
    # This utility function checks if a file is available in the local cache
    # and returns its path if found. Unlike `hf_hub_download` with `local_files_only=true`,
    # this method returns `nil` instead of raising an error when the file is not cached.
    #
    # @param repo_id [String] Repository ID
    # @param filename [String] File path in repository
    # @param repo_type [String] Type of repository
    # @param revision [String] Git revision
    # @param cache_dir [String, Pathname, nil] Custom cache directory
    # @return [Pathname, nil] Path to cached file, or nil if not found
    #
    # @example Check if file is cached
    #   path = FileDownload.try_to_load_from_cache(
    #     repo_id: "bert-base-uncased",
    #     filename: "config.json",
    #     revision: "main"
    #   )
    #   if path
    #     puts "File is cached at: #{path}"
    #   else
    #     puts "File not in cache"
    #   end
    def self.try_to_load_from_cache(
      repo_id:,
      filename:,
      repo_type: "model",
      revision: "main",
      cache_dir: nil
    )
      begin
        hf_hub_download(
          repo_id: repo_id,
          filename: filename,
          repo_type: repo_type,
          revision: revision,
          cache_dir: cache_dir,
          local_files_only: true
        )
      rescue LocalEntryNotFoundError
        nil
      end
    end

    # Generate the HuggingFace Hub URL for a file in a repository.
    #
    # @param repo_id [String] Repository ID
    # @param filename [String] File path in repository
    # @param repo_type [String] Type of repository
    # @param revision [String] Git revision (defaults to "main")
    # @param endpoint [String, nil] Custom endpoint URL
    # @return [String] Full URL to the file on HuggingFace Hub
    #
    # @example Generate URL for a model file
    #   url = FileDownload.hf_hub_url(
    #     repo_id: "bert-base-uncased",
    #     filename: "config.json"
    #   )
    #   # => "https://huggingface.co/bert-base-uncased/resolve/main/config.json"
    #
    # @example Generate URL for a dataset file
    #   url = FileDownload.hf_hub_url(
    #     repo_id: "squad",
    #     filename: "train.json",
    #     repo_type: "dataset",
    #     revision: "v1.0"
    #   )
    def self.hf_hub_url(
      repo_id:,
      filename:,
      repo_type: "model",
      revision: "main",
      endpoint: nil
    )
      repo_id = Utils::Validators.validate_repo_id(repo_id)
      filename = Utils::Validators.validate_filename(filename)
      repo_type = Utils::Validators.validate_repo_type(repo_type)
      revision = Utils::Validators.validate_revision(revision)

      endpoint ||= DurableHuggingfaceHub.configuration.endpoint
      endpoint = endpoint.chomp("/")

      "#{endpoint}/#{repo_type}s/#{repo_id}/resolve/#{revision}/#{filename}"
    end
  end
end
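
A note on the on-disk cache layout produced by the file above (a sketch inferred from get_storage_folder, download_file, and update_refs; the repository name is illustrative):

    require "huggingface_hub"

    DurableHuggingfaceHub::FileDownload.get_storage_folder("google/flan-t5-base")
    # => ~/.cache/huggingface/hub/models--google--flan-t5-base
    #
    # Inside that folder the code maintains:
    #   blobs/<etag>                       # content-addressed file bodies
    #   snapshots/<commit-sha>/<filename>  # per-revision tree of symlinks into blobs/
    #   refs/<branch-or-tag>               # text file holding the resolved <commit-sha>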
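A note on ETag handling: extract_etag normalizes the raw header value before it is used as a blob filename, and verify_blob then only compares that filename against the expected ETag; no content checksum is computed. Illustrative values:

    DurableHuggingfaceHub::FileDownload.extract_etag(%(W/"abc123"))  # => "abc123"
    DurableHuggingfaceHub::FileDownload.extract_etag(nil)            # => nil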
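A note on the glob filtering in filter_repo_files: File.fnmatch is called with File::FNM_PATHNAME, under which "*" does not cross "/" separators, so a pattern like "*.json" only matches top-level files. Illustrative values:

    File.fnmatch("*.json", "config.json", File::FNM_PATHNAME)          # => true
    File.fnmatch("*.json", "onnx/config.json", File::FNM_PATHNAME)     # => false
    File.fnmatch("**/*.json", "onnx/config.json", File::FNM_PATHNAME)  # => true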