@@ -85,16 +85,37 @@ std::filesystem::path expand_user_home(const std::string &path) {
85
85
return std::filesystem::path (path);
86
86
}
87
87
88
- std::string create_cache_system (const std::string &cache_dir,
89
- const std::string &repo_id) {
90
- std::string model_folder = std::string (" models/" + repo_id);
88
/// Builds the cache folder name for a repository.
///
/// The repository id is prefixed with "models/" and every '/' in the result
/// is replaced by "--", e.g. "org/name" becomes "models--org--name".
///
/// @param repo_id Repository identifier, typically "<owner>/<name>".
/// @return The flattened folder name (contains no '/').
std::string get_model_repo_path(const std::string &repo_id) {
  const std::string raw_name = "models/" + repo_id;

  std::string folder_name;
  // Every separator grows by one character; reserve for the common case.
  folder_name.reserve(raw_name.size() + 2);

  for (const char ch : raw_name) {
    if (ch == '/') {
      folder_name += "--";
    } else {
      folder_name += ch;
    }
  }
  return folder_name;
}
99
+
100
/// Looks for `filename` inside any revision folder below `snapshot_dir`.
///
/// Each direct child of `snapshot_dir` is treated as one snapshot revision
/// and probed for `<revision>/<filename>`. `filename` may contain
/// sub-folders (e.g. "onnx/model.onnx"), matching how the up-to-date lookup
/// composes `snapshots/<commit>/<file>`.
///
/// The function never throws on filesystem errors: a missing or unreadable
/// `snapshot_dir`, or stray non-directory entries inside it, are treated as
/// "not found".
///
/// @param snapshot_dir Path of the repository's "snapshots" directory.
/// @param filename     File to look for, relative to a revision folder.
/// @return Path of the first matching regular file (symlinks are followed),
///         or an empty string when no revision contains it. Which revision
///         wins when several contain the file is unspecified (directory
///         iteration order).
std::string find_outdated_file(const std::string &snapshot_dir,
                               const std::string &filename) {
  namespace fs = std::filesystem;

  // An empty name would resolve to the revision folder itself and an absolute
  // one would escape the cache entirely; neither can be a cached file.
  if (filename.empty() || fs::path(filename).is_absolute()) {
    return "";
  }

  std::error_code iter_ec;
  const fs::directory_iterator end;
  for (fs::directory_iterator version(snapshot_dir, iter_ec);
       !iter_ec && version != end; version.increment(iter_ec)) {
    const fs::path candidate = version->path() / filename;

    // Separate error_code so a failed probe (e.g. the entry is a plain file,
    // not a revision folder) does not abort the outer iteration.
    std::error_code probe_ec;
    if (fs::is_regular_file(candidate, probe_ec)) {
      return candidate.string();
    }
  }
  return "";
}
114
+
115
+ std::string create_cache_system (const std::string &cache_dir,
116
+ const std::string &repo_id) {
117
+ std::string model_folder = get_model_repo_path (repo_id);
118
+
98
119
std::string expanded_cache_dir = expand_user_home (cache_dir);
99
120
100
121
std::string model_cache_path = expanded_cache_dir + " /" + model_folder + " /" ;
@@ -112,7 +133,7 @@ std::string create_cache_system(const std::string &cache_dir,
112
133
113
134
size_t write_string_data (void *ptr, size_t size, size_t nmemb, void *stream) {
114
135
if (!stream) {
115
- log_error (" Error: stream is null!" );
136
+ log_error (" Stream is null!" );
116
137
return 0 ;
117
138
}
118
139
std::string *out = static_cast <std::string *>(stream);
@@ -122,12 +143,12 @@ size_t write_string_data(void *ptr, size_t size, size_t nmemb, void *stream) {
122
143
123
144
size_t write_file_data (void *ptr, size_t size, size_t nmemb, void *stream) {
124
145
if (!stream) {
125
- log_error (" Error: stream is null!" );
146
+ log_error (" Stream is null!" );
126
147
return 0 ;
127
148
}
128
149
std::ofstream *out = static_cast <std::ofstream *>(stream);
129
150
if (!out->is_open ()) {
130
- log_error (" Error: output file stream is not open!" );
151
+ log_error (" Output file stream is not open!" );
131
152
return 0 ;
132
153
}
133
154
out->write (static_cast <char *>(ptr), size * nmemb);
@@ -161,16 +182,68 @@ FileMetadata extract_metadata(const std::string &json) {
161
182
R"( \"lfs\"\s*:\s*\{[^}]*\"oid\"\s*:\s*\"([a-f0-9]{64})\")" )))
162
183
metadata.sha256 = match[1 ];
163
184
164
- // Extract "commit" ID
165
- if (std::regex_search (
166
- json, match,
167
- std::regex (
168
- R"( \"lastCommit\"\s*:\s*\{[^}]*\"id\"\s*:\s*\"([a-f0-9]{40})\")" )))
169
- metadata.commit = match[1 ];
170
-
171
185
return metadata;
172
186
}
173
187
188
+ std::string get_file_path (const std::string &cache_dir,
189
+ const std::string &repo_id, const std::string &file) {
190
+ std::string model_folder = get_model_repo_path (repo_id);
191
+
192
+ std::filesystem::path expanded_cache_dir = expand_user_home (cache_dir);
193
+ std::filesystem::path refs_file_path =
194
+ expanded_cache_dir / model_folder / " refs" / " main" ;
195
+
196
+ if (!std::filesystem::exists (refs_file_path)) {
197
+ log_debug (" refs file does not exist" );
198
+ return " " ;
199
+ }
200
+ std::ifstream refs_file (refs_file_path);
201
+ std::string commit;
202
+ refs_file >> commit;
203
+ refs_file.close ();
204
+ std::filesystem::path snapshot_file_path =
205
+ expanded_cache_dir / model_folder / " snapshots" / commit / file;
206
+ if (std::filesystem::exists (snapshot_file_path)) {
207
+ return snapshot_file_path.string ();
208
+ } else {
209
+ return " " ; // File does not exist
210
+ }
211
+ }
212
+
213
+ std::variant<std::string, CURLcode> get_model_commit (const std::string &repo) {
214
+ CURL *curl = curl_easy_init ();
215
+ if (!curl) {
216
+ return CURLE_FAILED_INIT;
217
+ }
218
+
219
+ std::string url =
220
+ " https://huggingface.co/api/models/" + repo + " /revision/main" ;
221
+ std::string response;
222
+
223
+ curl_easy_setopt (curl, CURLOPT_URL, url.c_str ());
224
+ curl_easy_setopt (curl, CURLOPT_HTTPHEADER, NULL );
225
+ curl_easy_setopt (curl, CURLOPT_WRITEFUNCTION, write_string_data);
226
+ curl_easy_setopt (curl, CURLOPT_WRITEDATA, &response);
227
+ curl_easy_setopt (curl, CURLOPT_FOLLOWLOCATION, 1L );
228
+ curl_easy_setopt (curl, CURLOPT_FAILONERROR, 1L );
229
+ curl_easy_setopt (curl, CURLOPT_HEADER, 0L );
230
+
231
+ CURLcode res = curl_easy_perform (curl);
232
+ curl_easy_cleanup (curl);
233
+ if (res != CURLE_OK) {
234
+ return res;
235
+ }
236
+
237
+ std::smatch match;
238
+ std::regex pattern (" \" sha\"\\ s*:\\ s*\" ([a-fA-F0-9]{40})\" " );
239
+
240
+ if (std::regex_search (response, match, pattern) && match.size () > 1 ) {
241
+ return match[1 ];
242
+ } else {
243
+ return std::string (); // Return empty string if not found
244
+ }
245
+ }
246
+
174
247
std::variant<struct FileMetadata , CURLcode>
175
248
get_model_metadata_from_hf (const std::string &repo, const std::string &file) {
176
249
CURL *curl = curl_easy_init ();
@@ -205,6 +278,10 @@ get_model_metadata_from_hf(const std::string &repo, const std::string &file) {
205
278
return res;
206
279
}
207
280
281
+ if (response.empty () || response == " []" ) {
282
+ return CURLE_REMOTE_FILE_NOT_FOUND;
283
+ }
284
+
208
285
return extract_metadata (response);
209
286
}
210
287
@@ -287,7 +364,6 @@ CURLcode perform_download(std::string url,
287
364
std::ios::binary | std::ios::app);
288
365
289
366
if (!file.is_open ()) {
290
- log_error (" Error: failed to open file stream!" );
291
367
return CURLE_FAILED_INIT;
292
368
}
293
369
@@ -328,51 +404,70 @@ struct DownloadResult hf_hub_download(const std::string &repo_id,
328
404
struct DownloadResult result;
329
405
result.success = true ;
330
406
331
- // 1. Check that model exists on Hugging Face
332
- auto metadata_result = get_model_metadata_from_hf (repo_id, filename);
333
- if (std::holds_alternative<CURLcode>(metadata_result)) {
334
- CURLcode err = std::get<CURLcode>(metadata_result);
407
+ log_info (" Downloading " + filename + " from " + repo_id);
335
408
336
- std::string refs_main_path = " models/" + repo_id;
337
- size_t pos = 0 ;
338
- while ((pos = refs_main_path.find (" /" , pos)) != std::string::npos) {
339
- refs_main_path.replace (pos, 1 , " --" );
340
- pos += 2 ;
409
+ // Check repo (accessibility and version)
410
+ auto commit_result = get_model_commit (repo_id);
411
+
412
+ if (std::holds_alternative<CURLcode>(commit_result)) {
413
+ CURLcode err = std::get<CURLcode>(commit_result);
414
+
415
+ std::string file_path = get_file_path (cache_dir, repo_id, filename);
416
+ if (!file_path.empty ()) {
417
+ log_info (" Using cached file." );
418
+ result.path = file_path;
419
+ result.success = true ;
420
+ return result;
341
421
}
342
422
343
- std::filesystem::path cache_model_dir =
344
- expand_user_home (" ~/.cache/huggingface/hub/" + refs_main_path + " /" );
345
- std::filesystem::path refs_file_path = cache_model_dir / " refs/main" ;
346
-
347
- if (std::filesystem::exists (refs_file_path)) {
348
- std::ifstream refs_file (refs_file_path);
349
- std::string commit;
350
- refs_file >> commit;
351
- refs_file.close ();
352
-
353
- std::filesystem::path snapshot_file_path =
354
- cache_model_dir / " snapshots" / commit / filename;
355
- if (std::filesystem::exists (snapshot_file_path)) {
356
- log_info (" Snapshot file exists. Skipping download..." );
357
- result.success = true ;
358
- result.path = snapshot_file_path;
359
- return result;
360
- }
423
+ std::string model_path = get_model_repo_path (repo_id);
424
+ std::string snapshot_path =
425
+ expand_user_home (cache_dir + " /" + model_path + " /snapshots" );
426
+ if (!std::filesystem::exists (snapshot_path)) {
427
+ log_info (snapshot_path);
428
+ log_error (" Repo not found (locally nor online): " + repo_id);
429
+ result.success = false ;
430
+ return result;
361
431
}
362
432
363
- log_error (" CURL metadata request failed: " +
433
+ std::string outdated_file = find_outdated_file (snapshot_path, filename);
434
+ if (!outdated_file.empty ()) {
435
+ log_info (" Using outdated cached file " + outdated_file);
436
+ result.path = outdated_file;
437
+ result.success = true ;
438
+ return result;
439
+ }
440
+
441
+ log_error (" Error getting model: " + std::string (curl_easy_strerror (err)));
442
+ result.success = false ;
443
+ return result;
444
+ }
445
+
446
+ std::string latest_commit = std::get<std::string>(commit_result);
447
+ if (latest_commit.empty ()) {
448
+ log_error (" Failed to retrieve the latest commit for repository: " +
449
+ repo_id);
450
+ result.success = false ;
451
+ return result;
452
+ }
453
+
454
+ // Check file accessibility
455
+ auto metadata_result = get_model_metadata_from_hf (repo_id, filename);
456
+
457
+ if (std::holds_alternative<CURLcode>(metadata_result)) {
458
+ CURLcode err = std::get<CURLcode>(metadata_result);
459
+ log_error (" Error getting metadata: " +
364
460
std::string (curl_easy_strerror (err)));
365
461
result.success = false ;
366
462
return result;
367
463
}
368
464
369
- // 2. Create Cache Dir Struct
465
+ // Create Cache Dir Struct
370
466
std::string cache_model_dir = create_cache_system (cache_dir, repo_id);
371
467
log_debug (" Cache directory: " + cache_model_dir);
372
- log_info (" Downloading " + filename + " from " + repo_id);
373
468
374
469
struct FileMetadata metadata = std::get<struct FileMetadata >(metadata_result);
375
- log_debug (" Commit: " + metadata. commit );
470
+ log_debug (" Commit: " + latest_commit );
376
471
log_debug (" Blob ID: " + metadata.oid );
377
472
log_debug (" Size: " + std::to_string (metadata.size ) + " bytes" );
378
473
log_debug (" SHA256: " + metadata.sha256 );
@@ -391,29 +486,22 @@ struct DownloadResult hf_hub_download(const std::string &repo_id,
391
486
}
392
487
393
488
std::filesystem::path snapshot_file_path (cache_model_dir + " snapshots/" +
394
- metadata. commit + " /" + filename);
489
+ latest_commit + " /" + filename);
395
490
std::filesystem::path refs_file_path (cache_model_dir + " refs/main" );
396
491
397
492
result.path = snapshot_file_path;
398
493
494
+ std::ofstream refs_file (refs_file_path);
495
+ refs_file << latest_commit << std::endl;
496
+ refs_file.close ();
497
+
399
498
if (std::filesystem::exists (snapshot_file_path) &&
400
499
std::filesystem::exists (blob_file_path) && !force_download) {
401
- log_info (" Snapshot file exists. Skipping download.. ." );
500
+ log_info (" Snapshot file exists. Using cached file ." );
402
501
return result;
403
502
}
404
503
405
- if (std::filesystem::exists (refs_file_path)) {
406
- std::ifstream refs_file (refs_file_path);
407
- std::string commit;
408
- refs_file >> commit;
409
- refs_file.close ();
410
- } else {
411
- std::ofstream refs_file (refs_file_path);
412
- refs_file << metadata.commit ;
413
- refs_file.close ();
414
- }
415
-
416
- // 3. Download the file
504
+ // 4. Download the file
417
505
std::string url =
418
506
" https://huggingface.co/" + repo_id + " /resolve/main/" + filename;
419
507
std::filesystem::create_directories (snapshot_file_path.parent_path ());
0 commit comments