Skip to content

Commit b00375c

Browse files
authored
Fixed refs file and better offline model handling (#7)
* If metadata request fail, check if file exists * Properly commit retrieval * Allow the usage of outdated models if no connection * Refactor for offline models
1 parent b35a584 commit b00375c

File tree

2 files changed

+147
-60
lines changed

2 files changed

+147
-60
lines changed

include/huggingface_hub.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,6 @@ namespace huggingface_hub {
4545
* This structure contains the SHA-256 hash, commit ID, and size of a file.
4646
*/
4747
struct FileMetadata {
48-
std::string commit; /**< Commit ID of the file */
4948
std::string type; /**< Type of the file (e.g., "model", "dataset") */
5049
std::string oid; /**< Object ID of the file */
5150
uint64_t size; /**< Size of the file in bytes */

src/huggingface_hub.cpp

Lines changed: 147 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -85,16 +85,37 @@ std::filesystem::path expand_user_home(const std::string &path) {
8585
return std::filesystem::path(path);
8686
}
8787

88-
std::string create_cache_system(const std::string &cache_dir,
89-
const std::string &repo_id) {
90-
std::string model_folder = std::string("models/" + repo_id);
88+
std::string get_model_repo_path(const std::string &repo_id) {
89+
std::string model_folder = "models/" + repo_id;
9190

9291
size_t pos = 0;
9392
while ((pos = model_folder.find("/", pos)) != std::string::npos) {
9493
model_folder.replace(pos, 1, "--");
9594
pos += 2;
9695
}
9796

97+
return model_folder;
98+
}
99+
100+
std::string find_outdated_file(const std::string &snapshot_dir,
101+
const std::string &filename) {
102+
for (const auto &version :
103+
std::filesystem::directory_iterator(snapshot_dir)) {
104+
for (const auto &file :
105+
std::filesystem::directory_iterator(version.path())) {
106+
if (file.path().filename() == filename) {
107+
return file.path();
108+
break;
109+
}
110+
}
111+
}
112+
return "";
113+
}
114+
115+
std::string create_cache_system(const std::string &cache_dir,
116+
const std::string &repo_id) {
117+
std::string model_folder = get_model_repo_path(repo_id);
118+
98119
std::string expanded_cache_dir = expand_user_home(cache_dir);
99120

100121
std::string model_cache_path = expanded_cache_dir + "/" + model_folder + "/";
@@ -112,7 +133,7 @@ std::string create_cache_system(const std::string &cache_dir,
112133

113134
size_t write_string_data(void *ptr, size_t size, size_t nmemb, void *stream) {
114135
if (!stream) {
115-
log_error("Error: stream is null!");
136+
log_error("Stream is null!");
116137
return 0;
117138
}
118139
std::string *out = static_cast<std::string *>(stream);
@@ -122,12 +143,12 @@ size_t write_string_data(void *ptr, size_t size, size_t nmemb, void *stream) {
122143

123144
size_t write_file_data(void *ptr, size_t size, size_t nmemb, void *stream) {
124145
if (!stream) {
125-
log_error("Error: stream is null!");
146+
log_error("Stream is null!");
126147
return 0;
127148
}
128149
std::ofstream *out = static_cast<std::ofstream *>(stream);
129150
if (!out->is_open()) {
130-
log_error("Error: output file stream is not open!");
151+
log_error("Output file stream is not open!");
131152
return 0;
132153
}
133154
out->write(static_cast<char *>(ptr), size * nmemb);
@@ -161,16 +182,68 @@ FileMetadata extract_metadata(const std::string &json) {
161182
R"(\"lfs\"\s*:\s*\{[^}]*\"oid\"\s*:\s*\"([a-f0-9]{64})\")")))
162183
metadata.sha256 = match[1];
163184

164-
// Extract "commit" ID
165-
if (std::regex_search(
166-
json, match,
167-
std::regex(
168-
R"(\"lastCommit\"\s*:\s*\{[^}]*\"id\"\s*:\s*\"([a-f0-9]{40})\")")))
169-
metadata.commit = match[1];
170-
171185
return metadata;
172186
}
173187

188+
std::string get_file_path(const std::string &cache_dir,
189+
const std::string &repo_id, const std::string &file) {
190+
std::string model_folder = get_model_repo_path(repo_id);
191+
192+
std::filesystem::path expanded_cache_dir = expand_user_home(cache_dir);
193+
std::filesystem::path refs_file_path =
194+
expanded_cache_dir / model_folder / "refs" / "main";
195+
196+
if (!std::filesystem::exists(refs_file_path)) {
197+
log_debug("refs file does not exist");
198+
return "";
199+
}
200+
std::ifstream refs_file(refs_file_path);
201+
std::string commit;
202+
refs_file >> commit;
203+
refs_file.close();
204+
std::filesystem::path snapshot_file_path =
205+
expanded_cache_dir / model_folder / "snapshots" / commit / file;
206+
if (std::filesystem::exists(snapshot_file_path)) {
207+
return snapshot_file_path.string();
208+
} else {
209+
return ""; // File does not exist
210+
}
211+
}
212+
213+
std::variant<std::string, CURLcode> get_model_commit(const std::string &repo) {
214+
CURL *curl = curl_easy_init();
215+
if (!curl) {
216+
return CURLE_FAILED_INIT;
217+
}
218+
219+
std::string url =
220+
"https://huggingface.co/api/models/" + repo + "/revision/main";
221+
std::string response;
222+
223+
curl_easy_setopt(curl, CURLOPT_URL, url.c_str());
224+
curl_easy_setopt(curl, CURLOPT_HTTPHEADER, NULL);
225+
curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, write_string_data);
226+
curl_easy_setopt(curl, CURLOPT_WRITEDATA, &response);
227+
curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L);
228+
curl_easy_setopt(curl, CURLOPT_FAILONERROR, 1L);
229+
curl_easy_setopt(curl, CURLOPT_HEADER, 0L);
230+
231+
CURLcode res = curl_easy_perform(curl);
232+
curl_easy_cleanup(curl);
233+
if (res != CURLE_OK) {
234+
return res;
235+
}
236+
237+
std::smatch match;
238+
std::regex pattern("\"sha\"\\s*:\\s*\"([a-fA-F0-9]{40})\"");
239+
240+
if (std::regex_search(response, match, pattern) && match.size() > 1) {
241+
return match[1];
242+
} else {
243+
return std::string(); // Return empty string if not found
244+
}
245+
}
246+
174247
std::variant<struct FileMetadata, CURLcode>
175248
get_model_metadata_from_hf(const std::string &repo, const std::string &file) {
176249
CURL *curl = curl_easy_init();
@@ -205,6 +278,10 @@ get_model_metadata_from_hf(const std::string &repo, const std::string &file) {
205278
return res;
206279
}
207280

281+
if (response.empty() || response == "[]") {
282+
return CURLE_REMOTE_FILE_NOT_FOUND;
283+
}
284+
208285
return extract_metadata(response);
209286
}
210287

@@ -287,7 +364,6 @@ CURLcode perform_download(std::string url,
287364
std::ios::binary | std::ios::app);
288365

289366
if (!file.is_open()) {
290-
log_error("Error: failed to open file stream!");
291367
return CURLE_FAILED_INIT;
292368
}
293369

@@ -328,51 +404,70 @@ struct DownloadResult hf_hub_download(const std::string &repo_id,
328404
struct DownloadResult result;
329405
result.success = true;
330406

331-
// 1. Check that model exists on Hugging Face
332-
auto metadata_result = get_model_metadata_from_hf(repo_id, filename);
333-
if (std::holds_alternative<CURLcode>(metadata_result)) {
334-
CURLcode err = std::get<CURLcode>(metadata_result);
407+
log_info("Downloading " + filename + " from " + repo_id);
335408

336-
std::string refs_main_path = "models/" + repo_id;
337-
size_t pos = 0;
338-
while ((pos = refs_main_path.find("/", pos)) != std::string::npos) {
339-
refs_main_path.replace(pos, 1, "--");
340-
pos += 2;
409+
// Check repo (accessibility and version)
410+
auto commit_result = get_model_commit(repo_id);
411+
412+
if (std::holds_alternative<CURLcode>(commit_result)) {
413+
CURLcode err = std::get<CURLcode>(commit_result);
414+
415+
std::string file_path = get_file_path(cache_dir, repo_id, filename);
416+
if (!file_path.empty()) {
417+
log_info("Using cached file.");
418+
result.path = file_path;
419+
result.success = true;
420+
return result;
341421
}
342422

343-
std::filesystem::path cache_model_dir =
344-
expand_user_home("~/.cache/huggingface/hub/" + refs_main_path + "/");
345-
std::filesystem::path refs_file_path = cache_model_dir / "refs/main";
346-
347-
if (std::filesystem::exists(refs_file_path)) {
348-
std::ifstream refs_file(refs_file_path);
349-
std::string commit;
350-
refs_file >> commit;
351-
refs_file.close();
352-
353-
std::filesystem::path snapshot_file_path =
354-
cache_model_dir / "snapshots" / commit / filename;
355-
if (std::filesystem::exists(snapshot_file_path)) {
356-
log_info("Snapshot file exists. Skipping download...");
357-
result.success = true;
358-
result.path = snapshot_file_path;
359-
return result;
360-
}
423+
std::string model_path = get_model_repo_path(repo_id);
424+
std::string snapshot_path =
425+
expand_user_home(cache_dir + "/" + model_path + "/snapshots");
426+
if (!std::filesystem::exists(snapshot_path)) {
427+
log_info(snapshot_path);
428+
log_error("Repo not found (locally nor online): " + repo_id);
429+
result.success = false;
430+
return result;
361431
}
362432

363-
log_error("CURL metadata request failed: " +
433+
std::string outdated_file = find_outdated_file(snapshot_path, filename);
434+
if (!outdated_file.empty()) {
435+
log_info("Using outdated cached file " + outdated_file);
436+
result.path = outdated_file;
437+
result.success = true;
438+
return result;
439+
}
440+
441+
log_error("Error getting model: " + std::string(curl_easy_strerror(err)));
442+
result.success = false;
443+
return result;
444+
}
445+
446+
std::string latest_commit = std::get<std::string>(commit_result);
447+
if (latest_commit.empty()) {
448+
log_error("Failed to retrieve the latest commit for repository: " +
449+
repo_id);
450+
result.success = false;
451+
return result;
452+
}
453+
454+
// Check file accessibility
455+
auto metadata_result = get_model_metadata_from_hf(repo_id, filename);
456+
457+
if (std::holds_alternative<CURLcode>(metadata_result)) {
458+
CURLcode err = std::get<CURLcode>(metadata_result);
459+
log_error("Error getting metadata: " +
364460
std::string(curl_easy_strerror(err)));
365461
result.success = false;
366462
return result;
367463
}
368464

369-
// 2. Create Cache Dir Struct
465+
// Create Cache Dir Struct
370466
std::string cache_model_dir = create_cache_system(cache_dir, repo_id);
371467
log_debug("Cache directory: " + cache_model_dir);
372-
log_info("Downloading " + filename + " from " + repo_id);
373468

374469
struct FileMetadata metadata = std::get<struct FileMetadata>(metadata_result);
375-
log_debug("Commit: " + metadata.commit);
470+
log_debug("Commit: " + latest_commit);
376471
log_debug("Blob ID: " + metadata.oid);
377472
log_debug("Size: " + std::to_string(metadata.size) + " bytes");
378473
log_debug("SHA256: " + metadata.sha256);
@@ -391,29 +486,22 @@ struct DownloadResult hf_hub_download(const std::string &repo_id,
391486
}
392487

393488
std::filesystem::path snapshot_file_path(cache_model_dir + "snapshots/" +
394-
metadata.commit + "/" + filename);
489+
latest_commit + "/" + filename);
395490
std::filesystem::path refs_file_path(cache_model_dir + "refs/main");
396491

397492
result.path = snapshot_file_path;
398493

494+
std::ofstream refs_file(refs_file_path);
495+
refs_file << latest_commit << std::endl;
496+
refs_file.close();
497+
399498
if (std::filesystem::exists(snapshot_file_path) &&
400499
std::filesystem::exists(blob_file_path) && !force_download) {
401-
log_info("Snapshot file exists. Skipping download...");
500+
log_info("Snapshot file exists. Using cached file.");
402501
return result;
403502
}
404503

405-
if (std::filesystem::exists(refs_file_path)) {
406-
std::ifstream refs_file(refs_file_path);
407-
std::string commit;
408-
refs_file >> commit;
409-
refs_file.close();
410-
} else {
411-
std::ofstream refs_file(refs_file_path);
412-
refs_file << metadata.commit;
413-
refs_file.close();
414-
}
415-
416-
// 3. Download the file
504+
// 4. Download the file
417505
std::string url =
418506
"https://huggingface.co/" + repo_id + "/resolve/main/" + filename;
419507
std::filesystem::create_directories(snapshot_file_path.parent_path());

0 commit comments

Comments
 (0)