Skip to content

Commit 41d49d6

Browse files
authored
Resumable downloads. (#84)
* Resumable downloads. * Clippy. * Proof of resumability through manual corruption. * Speeding up downloads (fewer writes) + more accurate estimates (moving window) * Remove unwrap.
1 parent 57c58af commit 41d49d6

File tree

4 files changed

+387
-55
lines changed

4 files changed

+387
-55
lines changed

src/api/mod.rs

Lines changed: 51 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
use indicatif::{ProgressBar, ProgressStyle};
1+
use std::{collections::VecDeque, time::Duration};
2+
3+
use indicatif::{style::ProgressTracker, HumanBytes, ProgressBar, ProgressStyle};
24
use serde::Deserialize;
35

46
/// The asynchronous version of the API
@@ -35,9 +37,9 @@ impl Progress for ProgressBar {
3537
self.set_length(size as u64);
3638
self.set_style(
3739
ProgressStyle::with_template(
38-
"{msg} [{elapsed_precise}] [{wide_bar}] {bytes}/{total_bytes} {bytes_per_sec} ({eta})",
39-
)
40-
.unwrap(), // .progress_chars("━ "),
40+
"{msg} [{elapsed_precise}] [{wide_bar}] {bytes}/{total_bytes} {bytes_per_sec_smoothed} ({eta})",
41+
).unwrap().with_key("bytes_per_sec_smoothed", MovingAvgRate::default())
42+
,
4143
);
4244
let maxlength = 30;
4345
let message = if filename.len() > maxlength {
@@ -73,3 +75,48 @@ pub struct RepoInfo {
7375
/// The commit sha of the repo.
7476
pub sha: String,
7577
}
78+
79+
#[derive(Clone, Default)]
80+
struct MovingAvgRate {
81+
samples: VecDeque<(std::time::Instant, u64)>,
82+
}
83+
84+
impl ProgressTracker for MovingAvgRate {
85+
fn clone_box(&self) -> Box<dyn ProgressTracker> {
86+
Box::new(self.clone())
87+
}
88+
89+
fn tick(&mut self, state: &indicatif::ProgressState, now: std::time::Instant) {
90+
// sample at most every 20ms
91+
if self
92+
.samples
93+
.back()
94+
.map_or(true, |(prev, _)| (now - *prev) > Duration::from_millis(20))
95+
{
96+
self.samples.push_back((now, state.pos()));
97+
}
98+
99+
while let Some(first) = self.samples.front() {
100+
if now - first.0 > Duration::from_secs(1) {
101+
self.samples.pop_front();
102+
} else {
103+
break;
104+
}
105+
}
106+
}
107+
108+
fn reset(&mut self, _state: &indicatif::ProgressState, _now: std::time::Instant) {
109+
self.samples = Default::default();
110+
}
111+
112+
fn write(&self, _state: &indicatif::ProgressState, w: &mut dyn std::fmt::Write) {
113+
match (self.samples.front(), self.samples.back()) {
114+
(Some((t0, p0)), Some((t1, p1))) if self.samples.len() > 1 => {
115+
let elapsed_ms = (*t1 - *t0).as_millis();
116+
let rate = ((p1 - p0) as f64 * 1000f64 / elapsed_ms as f64) as u64;
117+
write!(w, "{}/s", HumanBytes(rate)).unwrap()
118+
}
119+
_ => write!(w, "-").unwrap(),
120+
}
121+
}
122+
}

src/api/sync.rs

Lines changed: 106 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,9 @@ const AUTHORIZATION: &str = "Authorization";
2828
type HeaderMap = HashMap<&'static str, String>;
2929
type HeaderName = &'static str;
3030

31+
/// Extension given to the partially downloaded (resumable) file by the sync API
32+
const EXTENTION: &str = ".part";
33+
3134
struct Wrapper<'a, P: Progress, R: Read> {
3235
progress: &'a mut P,
3336
inner: R,
@@ -104,6 +107,10 @@ pub enum ApiError {
104107
#[error("Native tls: {0}")]
105108
#[cfg(feature = "native-tls")]
106109
Native(#[from] native_tls::Error),
110+
111+
/// The part file is corrupted
112+
#[error("Invalid part file - corrupted file")]
113+
InvalidResume,
107114
}
108115

109116
/// Helper to create [`Api`] with all the options.
@@ -436,15 +443,26 @@ impl Api {
436443
url: &str,
437444
size: usize,
438445
mut progress: P,
446+
tmp_path: PathBuf,
439447
filename: &str,
440448
) -> Result<PathBuf, ApiError> {
441449
progress.init(size, filename);
442-
let filepath = self.cache.temp_path();
450+
let filepath = tmp_path;
443451

444452
// Create the file and set everything properly
445-
let mut file = std::fs::File::create(&filepath)?;
446453

447-
let mut res = self.download_from(url, 0u64, size, &mut file, filename, &mut progress);
454+
let mut file = match std::fs::OpenOptions::new().append(true).open(&filepath) {
455+
Ok(f) => f,
456+
Err(_) => std::fs::File::create(&filepath)?,
457+
};
458+
459+
// When resuming, the bytes already on disk give the offset to restart from.
460+
let start = file.metadata()?.len();
461+
if start > size as u64 {
462+
return Err(ApiError::InvalidResume);
463+
}
464+
465+
let mut res = self.download_from(url, start, size, &mut file, filename, &mut progress);
448466
if self.max_retries > 0 {
449467
let mut i = 0;
450468
while let Err(dlerr) = res {
@@ -631,9 +649,11 @@ impl ApiRepo {
631649
.blob_path(&metadata.etag);
632650
std::fs::create_dir_all(blob_path.parent().unwrap())?;
633651

634-
let tmp_filename = self
635-
.api
636-
.download_tempfile(&url, metadata.size, progress, filename)?;
652+
let mut tmp_path = blob_path.clone();
653+
tmp_path.set_extension(EXTENTION);
654+
let tmp_filename =
655+
self.api
656+
.download_tempfile(&url, metadata.size, progress, tmp_path, filename)?;
637657

638658
std::fs::rename(tmp_filename, &blob_path)?;
639659
let mut pointer_path = self
@@ -704,6 +724,7 @@ mod tests {
704724
use rand::{distributions::Alphanumeric, Rng};
705725
use serde_json::{json, Value};
706726
use sha2::{Digest, Sha256};
727+
use std::io::{Seek, SeekFrom, Write};
707728

708729
struct TempDir {
709730
path: PathBuf,
@@ -756,6 +777,85 @@ mod tests {
756777
assert_eq!(cache_path, downloaded_path);
757778
}
758779

780+
#[test]
fn resume() {
    // Digest of the intact `config.json`.
    let intact_sha = hex!("b908f2b7227d4d31a2105dfa31095e28d304f9bc938bfaaa57ee2cacf1f62d32");
    // Digest expected once a deliberately corrupted half has been completed by a resume.
    let corrupted_sha = hex!("32b83c94ee55a8d43d68b03a859975f6789d647342ddeb2326fcd5e0127035b5");

    let tmp = TempDir::new();
    let api = ApiBuilder::new()
        .with_progress(false)
        .with_cache_dir(tmp.path.clone())
        .build()
        .unwrap();
    let model_id = "julien-c/dummy-unknown".to_string();

    // Baseline: a fresh download yields the intact file.
    let downloaded_path = api.model(model_id.clone()).download("config.json").unwrap();
    assert!(downloaded_path.exists());
    let digest = Sha256::digest(std::fs::read(&*downloaded_path).unwrap());
    assert_eq!(digest[..], intact_sha);

    // Round 1: cut the blob at a random length, park it under the `.part`
    // name and delete the pointer so the next download has to run again.
    let blob = std::fs::canonicalize(&downloaded_path).unwrap();
    let handle = std::fs::OpenOptions::new().write(true).open(&blob).unwrap();
    let full_len = handle.metadata().unwrap().len();
    let ratio: f32 = rand::random();
    let kept_len = (full_len as f32 * ratio) as u64;
    handle.set_len(kept_len).unwrap();
    let mut blob_part = blob.clone();
    blob_part.set_extension(".part");
    std::fs::rename(blob, &blob_part).unwrap();
    std::fs::remove_file(&downloaded_path).unwrap();
    let partial = std::fs::read(&*blob_part).unwrap();
    assert_eq!(partial.len() as u64, kept_len);
    // The truncated content no longer hashes to the intact digest.
    assert!(Sha256::digest(partial)[..] != intact_sha);

    let new_downloaded_path = api.model(model_id.clone()).download("config.json").unwrap();
    let digest = Sha256::digest(std::fs::read(&*new_downloaded_path).unwrap());
    assert_eq!(downloaded_path, new_downloaded_path);
    assert_eq!(digest[..], intact_sha);

    // Round 2: prove the previous round really resumed. Damage the kept half;
    // if the existing bytes are reused, the damage survives in the final file.
    let blob = std::fs::canonicalize(&downloaded_path).unwrap();
    let mut handle = std::fs::OpenOptions::new().write(true).open(&blob).unwrap();
    let full_len = handle.metadata().unwrap().len();
    // Fixed ratio so the corrupted digest is reproducible.
    let ratio: f32 = 0.5;
    let kept_len = (full_len as f32 * ratio) as u64;
    handle.set_len(kept_len).unwrap();
    // Overwrite the last kept byte with zero.
    handle.seek(SeekFrom::Start(kept_len - 1)).unwrap();
    handle.write_all(&[0]).unwrap();

    let mut blob_part = blob.clone();
    blob_part.set_extension(".part");
    std::fs::rename(blob, &blob_part).unwrap();
    std::fs::remove_file(&downloaded_path).unwrap();
    let partial = std::fs::read(&*blob_part).unwrap();
    assert_eq!(partial.len() as u64, kept_len);
    // The damaged content no longer hashes to the intact digest.
    assert!(Sha256::digest(partial)[..] != intact_sha);

    let new_downloaded_path = api.model(model_id.clone()).download("config.json").unwrap();
    let val = Sha256::digest(std::fs::read(&*new_downloaded_path).unwrap());
    println!("Sha {val:#x}");
    assert_eq!(downloaded_path, new_downloaded_path);
    // Matching the corrupted digest shows the damaged prefix was kept.
    assert_eq!(val[..], corrupted_sha);
}
858+
759859
#[test]
760860
fn simple_with_retries() {
761861
let tmp = TempDir::new();

0 commit comments

Comments
 (0)