@@ -28,6 +28,9 @@ const AUTHORIZATION: &str = "Authorization";
2828type HeaderMap = HashMap < & ' static str , String > ;
2929type HeaderName = & ' static str ;
3030
31+ /// Extension given to the on-disk partial ("part") file of a resumable download
32+ const EXTENTION : & str = ".part" ;
33+
3134struct Wrapper < ' a , P : Progress , R : Read > {
3235 progress : & ' a mut P ,
3336 inner : R ,
@@ -104,6 +107,10 @@ pub enum ApiError {
104107 #[ error( "Native tls: {0}" ) ]
105108 #[ cfg( feature = "native-tls" ) ]
106109 Native ( #[ from] native_tls:: Error ) ,
110+
111+ /// The part file is corrupted
112+ #[ error( "Invalid part file - corrupted file" ) ]
113+ InvalidResume ,
107114}
108115
109116/// Helper to create [`Api`] with all the options.
@@ -436,15 +443,26 @@ impl Api {
436443 url : & str ,
437444 size : usize ,
438445 mut progress : P ,
446+ tmp_path : PathBuf ,
439447 filename : & str ,
440448 ) -> Result < PathBuf , ApiError > {
441449 progress. init ( size, filename) ;
442- let filepath = self . cache . temp_path ( ) ;
450+ let filepath = tmp_path ;
443451
444452 // Create the file and set everything properly
445- let mut file = std:: fs:: File :: create ( & filepath) ?;
446453
447- let mut res = self . download_from ( url, 0u64 , size, & mut file, filename, & mut progress) ;
454+ let mut file = match std:: fs:: OpenOptions :: new ( ) . append ( true ) . open ( & filepath) {
455+ Ok ( f) => f,
456+ Err ( _) => std:: fs:: File :: create ( & filepath) ?,
457+ } ;
458+
459+ // In case of resume.
460+ let start = file. metadata ( ) ?. len ( ) ;
461+ if start > size as u64 {
462+ return Err ( ApiError :: InvalidResume ) ;
463+ }
464+
465+ let mut res = self . download_from ( url, start, size, & mut file, filename, & mut progress) ;
448466 if self . max_retries > 0 {
449467 let mut i = 0 ;
450468 while let Err ( dlerr) = res {
@@ -631,9 +649,11 @@ impl ApiRepo {
631649 . blob_path ( & metadata. etag ) ;
632650 std:: fs:: create_dir_all ( blob_path. parent ( ) . unwrap ( ) ) ?;
633651
634- let tmp_filename = self
635- . api
636- . download_tempfile ( & url, metadata. size , progress, filename) ?;
652+ let mut tmp_path = blob_path. clone ( ) ;
653+ tmp_path. set_extension ( EXTENTION ) ;
654+ let tmp_filename =
655+ self . api
656+ . download_tempfile ( & url, metadata. size , progress, tmp_path, filename) ?;
637657
638658 std:: fs:: rename ( tmp_filename, & blob_path) ?;
639659 let mut pointer_path = self
@@ -704,6 +724,7 @@ mod tests {
704724 use rand:: { distributions:: Alphanumeric , Rng } ;
705725 use serde_json:: { json, Value } ;
706726 use sha2:: { Digest , Sha256 } ;
727+ use std:: io:: { Seek , SeekFrom , Write } ;
707728
708729 struct TempDir {
709730 path : PathBuf ,
@@ -756,6 +777,85 @@ mod tests {
756777 assert_eq ! ( cache_path, downloaded_path) ;
757778 }
758779
/// Checks that an interrupted download is resumed from its part file
/// instead of being restarted from scratch.
///
/// Round 1 truncates a good blob to a random length, stages it as the part
/// file and expects the re-download to reproduce the pristine SHA-256.
/// Round 2 additionally overwrites the last kept byte: if the download is
/// truly resumed, that bad byte survives and the final file has a different,
/// known SHA-256.
#[test]
fn resume() {
    /// Turns the fully downloaded file behind `pointer` into a partial
    /// download and returns the bytes left in the part file.
    ///
    /// The blob is truncated to `keep_ratio` of its length, optionally has
    /// its last remaining byte overwritten with `0`, is renamed to the part
    /// file, and `pointer` is deleted so the next `download` call has to
    /// fetch again.
    ///
    /// Panics (failing the test) on any I/O error, and when
    /// `corrupt_last_byte` is set while the truncated length is zero.
    fn stage_partial_blob(
        pointer: &std::path::Path,
        keep_ratio: f32,
        corrupt_last_byte: bool,
    ) -> Vec<u8> {
        let blob = std::fs::canonicalize(pointer).unwrap();
        let mut file = std::fs::OpenOptions::new().write(true).open(&blob).unwrap();
        let full_len = file.metadata().unwrap().len();
        let kept_len = (full_len as f32 * keep_ratio) as u64;
        file.set_len(kept_len).unwrap();
        if corrupt_last_byte {
            // Corrupting by changing a single byte.
            file.seek(SeekFrom::Start(kept_len - 1)).unwrap();
            file.write_all(&[0]).unwrap();
        }

        // Must mirror the call the downloader makes to derive its part file:
        // `set_extension` inserts its own `.`, so a dot-prefixed argument
        // produces a `<blob>..part` name.
        let mut part = blob.clone();
        part.set_extension(".part");
        std::fs::rename(&blob, &part).unwrap();
        std::fs::remove_file(pointer).unwrap();

        let remaining = std::fs::read(&part).unwrap();
        assert_eq!(remaining.len() as u64, kept_len);
        remaining
    }

    let tmp = TempDir::new();
    let api = ApiBuilder::new()
        .with_progress(false)
        .with_cache_dir(tmp.path.clone())
        .build()
        .unwrap();
    let model_id = "julien-c/dummy-unknown".to_string();

    // SHA-256 of the untouched `config.json`.
    let pristine_sha = hex!("b908f2b7227d4d31a2105dfa31095e28d304f9bc938bfaaa57ee2cacf1f62d32");

    let downloaded_path = api.model(model_id.clone()).download("config.json").unwrap();
    assert!(downloaded_path.exists());
    let val = Sha256::digest(std::fs::read(&*downloaded_path).unwrap());
    assert_eq!(val[..], pristine_sha);

    // Round 1: plain truncation at a random point, the resumed file must be
    // identical to the pristine one.
    let keep_ratio: f32 = rand::random();
    let partial = stage_partial_blob(&downloaded_path, keep_ratio, false);
    // We modified the sha.
    assert!(Sha256::digest(partial)[..] != pristine_sha);
    let new_downloaded_path = api.model(model_id.clone()).download("config.json").unwrap();
    let val = Sha256::digest(std::fs::read(&*new_downloaded_path).unwrap());
    assert_eq!(downloaded_path, new_downloaded_path);
    assert_eq!(val[..], pristine_sha);

    // Round 2: here we prove the previous part was correctly resuming by
    // purposefully corrupting the file. The ratio is not random so that the
    // corrupted sha is consistent.
    let partial = stage_partial_blob(&downloaded_path, 0.5, true);
    // We modified the sha.
    assert!(Sha256::digest(partial)[..] != pristine_sha);
    let new_downloaded_path = api.model(model_id.clone()).download("config.json").unwrap();
    let val = Sha256::digest(std::fs::read(&*new_downloaded_path).unwrap());
    assert_eq!(downloaded_path, new_downloaded_path);
    assert_eq!(
        val[..],
        // Corrupted sha
        hex!("32b83c94ee55a8d43d68b03a859975f6789d647342ddeb2326fcd5e0127035b5")
    );
}
858+
759859 #[ test]
760860 fn simple_with_retries ( ) {
761861 let tmp = TempDir :: new ( ) ;
0 commit comments