/src/download.jl

https://github.com/fredrikekre/HTTP.jl · Julia · 142 lines · 116 code · 18 blank · 8 comment · 13 complexity · cab981109378b7e7a0c043febaef04b0 MD5 · raw file

  1. using .Pairs
  2. """
  3. safer_joinpath(basepart, parts...)
  4. A variation on `joinpath`, that is more resistant to directory traversal attacks.
  5. The parts to be joined (excluding the `basepart`),
  6. are not allowed to contain `..`, or begin with a `/`.
  7. If they do then this throws an `DomainError`.
  8. """
  9. function safer_joinpath(basepart, parts...)
  10. explain = "Possible directory traversal attack detected."
  11. for part in parts
  12. occursin("..", part) && throw(DomainError(part, "contains \"..\". $explain"))
  13. startswith(part, '/') && throw(DomainError(part, "begins with \"/\". $explain"))
  14. end
  15. joinpath(basepart, parts...)
  16. end
  17. function try_get_filename_from_headers(headers)
  18. content_disp = getkv(headers, "Content-Disposition")
  19. if content_disp != nothing
  20. # extract out of Content-Disposition line
  21. # rough version of what is needed in https://github.com/JuliaWeb/HTTP.jl/issues/179
  22. filename_part = match(r"filename\s*=\s*(.*)", content_disp)
  23. if filename_part != nothing
  24. filename = filename_part[1]
  25. quoted_filename = match(r"\"(.*)\"", filename)
  26. if quoted_filename != nothing
  27. # It was in quotes, so it will be double escaped
  28. filename = unescape_string(quoted_filename[1])
  29. end
  30. return filename
  31. end
  32. end
  33. return nothing
  34. end
  35. function try_get_filename_from_remote_path(target)
  36. target == "" && return nothing
  37. filename = basename(target)
  38. if filename == ""
  39. try_get_filename_from_remote_path(dirname(target))
  40. else
  41. filename
  42. end
  43. end
  44. determine_file(::Nothing, resp) = determine_file(tempdir(), resp)
  45. # ^ We want to the filename if possible because extension is useful for FileIO.jl
  46. function determine_file(path, resp)
  47. # get the name
  48. name = if isdir(path)
  49. # got to to workout what file to put there
  50. filename = something(
  51. try_get_filename_from_headers(resp.headers),
  52. try_get_filename_from_remote_path(resp.request.target),
  53. basename(tempname()) # fallback, basically a random string
  54. )
  55. safer_joinpath(path, filename)
  56. else
  57. # It is a file, we are done.
  58. path
  59. end
  60. # get the extension, if we are going to save it in encoded form.
  61. if header(resp, "Content-Encoding") == "gzip"
  62. name *= ".gz"
  63. end
  64. name
  65. end
  66. """
  67. download(url, [local_path], [headers]; update_period=1, kw...)
  68. Similar to `Base.download` this downloads a file, returning the filename.
  69. If the `local_path`:
  70. - is not provided, then it is saved in a temporary directory
  71. - if part to a directory is provided then it is saved into that directory
  72. - otherwise the local path is uses as the filename to save to.
  73. When saving into a directory, the filename is determined (where possible),
  74. from the rules of the HTTP.
  75. - `update_period` controls how often (in seconds) to report the progress.
  76. - set to `Inf` to disable reporting
  77. - `headers` specifies headers to be used for the HTTP GET request
  78. - any additional keyword args (`kw...`) are passed on to the HTTP request.
  79. """
  80. function download(url::AbstractString, local_path=nothing, headers=Header[]; update_period=1, kw...)
  81. format_progress(x) = round(x, digits=4)
  82. format_bytes(x) = !isfinite(x) ? "∞ B" : Base.format_bytes(x)
  83. format_seconds(x) = "$(round(x; digits=2)) s"
  84. format_bytes_per_second(x) = format_bytes(x) * "/s"
  85. @debug 1 "downloading $url"
  86. local file
  87. HTTP.open("GET", url, headers; kw...) do stream
  88. resp = startread(stream)
  89. file = determine_file(local_path, resp)
  90. total_bytes = parse(Float64, getkv(resp.headers, "Content-Length", "NaN"))
  91. downloaded_bytes = 0
  92. start_time = now()
  93. prev_time = now()
  94. function report_callback()
  95. prev_time = now()
  96. taken_time = (prev_time - start_time).value / 1000 # in seconds
  97. average_speed = downloaded_bytes / taken_time
  98. remaining_bytes = total_bytes - downloaded_bytes
  99. remaining_time = remaining_bytes / average_speed
  100. completion_progress = downloaded_bytes / total_bytes
  101. @info("Downloading",
  102. source=url,
  103. dest = file,
  104. progress = completion_progress |> format_progress,
  105. time_taken = taken_time |> format_seconds,
  106. time_remaining = remaining_time |> format_seconds,
  107. average_speed = average_speed |> format_bytes_per_second,
  108. downloaded = downloaded_bytes |> format_bytes,
  109. remaining = remaining_bytes |> format_bytes,
  110. total = total_bytes |> format_bytes,
  111. )
  112. end
  113. Base.open(file, "w") do fh
  114. while(!eof(stream))
  115. downloaded_bytes += write(fh, readavailable(stream))
  116. if now() - prev_time > Millisecond(1000update_period)
  117. report_callback()
  118. end
  119. end
  120. end
  121. report_callback()
  122. end
  123. file
  124. end