/src/download.jl

https://github.com/JuliaWeb/HTTP.jl · Julia · 147 lines · 121 code · 17 blank · 9 comment · 16 complexity · 7836b00d7697d0d11d3e109192b4d68c MD5 · raw file

  1. using .Pairs
  2. """
  3. safer_joinpath(basepart, parts...)
  4. A variation on `joinpath`, that is more resistant to directory traversal attacks.
  5. The parts to be joined (excluding the `basepart`),
  6. are not allowed to contain `..`, or begin with a `/`.
  7. If they do then this throws an `DomainError`.
  8. """
  9. function safer_joinpath(basepart, parts...)
  10. explain = "Possible directory traversal attack detected."
  11. for part in parts
  12. occursin("..", part) && throw(DomainError(part, "contains \"..\". $explain"))
  13. startswith(part, '/') && throw(DomainError(part, "begins with \"/\". $explain"))
  14. end
  15. joinpath(basepart, parts...)
  16. end
  17. function try_get_filename_from_headers(resp)
  18. content_disp = header(resp, "Content-Disposition")
  19. if content_disp != ""
  20. # extract out of Content-Disposition line
  21. # rough version of what is needed in https://github.com/JuliaWeb/HTTP.jl/issues/179
  22. filename_part = match(r"filename\s*=\s*(.*)", content_disp)
  23. if filename_part != nothing
  24. filename = filename_part[1]
  25. quoted_filename = match(r"\"(.*)\"", filename)
  26. if quoted_filename != nothing
  27. # It was in quotes, so it will be double escaped
  28. filename = unescape_string(quoted_filename[1])
  29. end
  30. return filename == "" ? nothing : filename
  31. end
  32. end
  33. return nothing
  34. end
  35. function try_get_filename_from_remote_path(target)
  36. target == "" && return nothing
  37. filename = basename(target)
  38. if filename == ""
  39. try_get_filename_from_remote_path(dirname(target))
  40. else
  41. filename
  42. end
  43. end
  44. determine_file(::Nothing, resp) = determine_file(tempdir(), resp)
  45. # ^ We want to the filename if possible because extension is useful for FileIO.jl
  46. function determine_file(path, resp)
  47. # get the name
  48. name = if isdir(path)
  49. # we have been given a path to a directory
  50. # got to to workout what file to put there
  51. filename = something(
  52. try_get_filename_from_headers(resp),
  53. try_get_filename_from_remote_path(resp.request.target),
  54. basename(tempname()) # fallback, basically a random string
  55. )
  56. safer_joinpath(path, filename)
  57. else
  58. # We have been given a full filepath
  59. path
  60. end
  61. # get the extension, if we are going to save it in encoded form.
  62. if header(resp, "Content-Encoding") == "gzip"
  63. name *= ".gz"
  64. end
  65. name
  66. end
  67. """
  68. download(url, [local_path], [headers]; update_period=1, kw...)
  69. Similar to `Base.download` this downloads a file, returning the filename.
  70. If the `local_path`:
  71. - is not provided, then it is saved in a temporary directory
  72. - if part to a directory is provided then it is saved into that directory
  73. - otherwise the local path is uses as the filename to save to.
  74. When saving into a directory, the filename is determined (where possible),
  75. from the rules of the HTTP.
  76. - `update_period` controls how often (in seconds) to report the progress.
  77. - set to `Inf` to disable reporting
  78. - `headers` specifies headers to be used for the HTTP GET request
  79. - any additional keyword args (`kw...`) are passed on to the HTTP request.
  80. """
  81. function download(url::AbstractString, local_path=nothing, headers=Header[]; update_period=1, kw...)
  82. format_progress(x) = round(x, digits=4)
  83. format_bytes(x) = !isfinite(x) ? "∞ B" : Base.format_bytes(x)
  84. format_seconds(x) = "$(round(x; digits=2)) s"
  85. format_bytes_per_second(x) = format_bytes(x) * "/s"
  86. @debug 1 "downloading $url"
  87. local file
  88. HTTP.open("GET", url, headers; kw...) do stream
  89. resp = startread(stream)
  90. eof(stream) && return # don't do anything for streams we can't read (yet)
  91. file = determine_file(local_path, resp)
  92. total_bytes = parse(Float64, header(resp, "Content-Length", "NaN"))
  93. downloaded_bytes = 0
  94. start_time = now()
  95. prev_time = now()
  96. function report_callback()
  97. prev_time = now()
  98. taken_time = (prev_time - start_time).value / 1000 # in seconds
  99. average_speed = downloaded_bytes / taken_time
  100. remaining_bytes = total_bytes - downloaded_bytes
  101. remaining_time = remaining_bytes / average_speed
  102. completion_progress = downloaded_bytes / total_bytes
  103. @info("Downloading",
  104. source=url,
  105. dest = file,
  106. progress = completion_progress |> format_progress,
  107. time_taken = taken_time |> format_seconds,
  108. time_remaining = remaining_time |> format_seconds,
  109. average_speed = average_speed |> format_bytes_per_second,
  110. downloaded = downloaded_bytes |> format_bytes,
  111. remaining = remaining_bytes |> format_bytes,
  112. total = total_bytes |> format_bytes,
  113. )
  114. end
  115. Base.open(file, "w") do fh
  116. while(!eof(stream))
  117. downloaded_bytes += write(fh, readavailable(stream))
  118. if !isinf(update_period)
  119. if now() - prev_time > Millisecond(round(1000update_period))
  120. report_callback()
  121. end
  122. end
  123. end
  124. end
  125. if !isinf(update_period)
  126. report_callback()
  127. end
  128. end
  129. file
  130. end