PageRenderTime 46ms CodeModel.GetById 18ms RepoModel.GetById 0ms app.codeStats 0ms

/src/Snap/Internal/Http/Server/TLS.hs

http://github.com/snapframework/snap-server
Haskell | 165 lines | 118 code | 29 blank | 18 comment | 2 complexity | 1fdb032777526ccdab2ea1236cdbbe43 MD5 | raw file
Possible License(s): BSD-3-Clause
  1. {-# LANGUAGE CPP #-}
  2. {-# LANGUAGE DeriveDataTypeable #-}
  3. {-# LANGUAGE OverloadedStrings #-}
  4. {-# LANGUAGE ScopedTypeVariables #-}
  5. ------------------------------------------------------------------------------
  6. module Snap.Internal.Http.Server.TLS
  7. ( TLSException(..)
  8. , withTLS
  9. , bindHttps
  10. , httpsAcceptFunc
  11. , sendFileFunc
  12. ) where
  13. ------------------------------------------------------------------------------
  14. import Data.ByteString.Char8 (ByteString)
  15. import qualified Data.ByteString.Char8 as S
  16. import Data.Typeable (Typeable)
  17. import Network.Socket (Socket)
  18. #ifdef OPENSSL
  19. import Control.Exception (Exception, bracketOnError, finally, onException, throwIO)
  20. import Control.Monad (when)
  21. import Data.ByteString.Builder (byteString)
  22. import qualified Network.Socket as Socket
  23. import OpenSSL (withOpenSSL)
  24. import OpenSSL.Session (SSL, SSLContext)
  25. import qualified OpenSSL.Session as SSL
  26. import Prelude (Bool, FilePath, IO, Int, Maybe (..), Monad (..), Show, flip, fromIntegral, not, ($), ($!))
  27. import Snap.Internal.Http.Server.Address (getAddress)
  28. import Snap.Internal.Http.Server.Socket (acceptAndInitialize, bindSocket)
  29. import qualified System.IO.Streams as Streams
  30. import qualified System.IO.Streams.SSL as SStreams
  31. #else
  32. import Control.Exception (Exception, throwIO)
  33. import Prelude (Bool, FilePath, IO, Int, Show, id, ($))
  34. #endif
  35. ------------------------------------------------------------------------------
  36. import Snap.Internal.Http.Server.Types (AcceptFunc (..), SendFileHandler)
  37. ------------------------------------------------------------------------------
  38. data TLSException = TLSException S.ByteString
  39. deriving (Show, Typeable)
  40. instance Exception TLSException
  41. #ifndef OPENSSL
-- Placeholder aliases for builds where OPENSSL is not defined, so that the
-- stub signatures in this branch can still mention 'SSLContext' and 'SSL'.
type SSLContext = ()
type SSL = ()
  44. ------------------------------------------------------------------------------
  45. sslNotSupportedException :: TLSException
  46. sslNotSupportedException = TLSException $ S.concat [
  47. "This version of snap-server was not built with SSL "
  48. , "support.\n"
  49. , "Please compile snap-server with -fopenssl to enable it."
  50. ]
  51. ------------------------------------------------------------------------------
-- | Stub used when OpenSSL support is compiled out: the given action is
-- returned unchanged.
withTLS :: IO a -> IO a
withTLS = id
  54. ------------------------------------------------------------------------------
-- | Throw 'sslNotSupportedException'. The result type is fully polymorphic,
-- so this can stand in for any 'IO' action in the stubs of this branch.
barf :: IO a
barf = throwIO sslNotSupportedException
  57. ------------------------------------------------------------------------------
  58. bindHttps :: ByteString -> Int -> FilePath -> Bool -> FilePath
  59. -> IO (Socket, SSLContext)
  60. bindHttps _ _ _ _ _ = barf
  61. ------------------------------------------------------------------------------
  62. httpsAcceptFunc :: Socket -> SSLContext -> AcceptFunc
  63. httpsAcceptFunc _ _ = AcceptFunc $ \restore -> restore barf
  64. ------------------------------------------------------------------------------
  65. sendFileFunc :: SSL -> Socket -> SendFileHandler
  66. sendFileFunc _ _ _ _ _ _ _ = barf
  67. #else
  68. ------------------------------------------------------------------------------
  69. withTLS :: IO a -> IO a
  70. withTLS = withOpenSSL
  71. ------------------------------------------------------------------------------
  72. bindHttps :: ByteString
  73. -> Int
  74. -> FilePath
  75. -> Bool
  76. -> FilePath
  77. -> IO (Socket, SSLContext)
  78. bindHttps bindAddress bindPort cert chainCert key =
  79. withTLS $
  80. bracketOnError
  81. (bindSocket bindAddress bindPort)
  82. Socket.close
  83. $ \sock -> do
  84. ctx <- SSL.context
  85. SSL.contextSetPrivateKeyFile ctx key
  86. if chainCert
  87. then SSL.contextSetCertificateChainFile ctx cert
  88. else SSL.contextSetCertificateFile ctx cert
  89. certOK <- SSL.contextCheckPrivateKey ctx
  90. when (not certOK) $ do
  91. throwIO $ TLSException certificateError
  92. return (sock, ctx)
  93. where
  94. certificateError =
  95. "OpenSSL says that the certificate doesn't match the private key!"
  96. ------------------------------------------------------------------------------
  97. httpsAcceptFunc :: Socket
  98. -> SSLContext
  99. -> AcceptFunc
  100. httpsAcceptFunc boundSocket ctx =
  101. AcceptFunc $ \restore ->
  102. acceptAndInitialize boundSocket restore $ \(sock, remoteAddr) -> do
  103. localAddr <- Socket.getSocketName sock
  104. (localPort, localHost) <- getAddress localAddr
  105. (remotePort, remoteHost) <- getAddress remoteAddr
  106. ssl <- restore (SSL.connection ctx sock)
  107. restore (SSL.accept ssl) `onException` Socket.close sock
  108. (readEnd, writeEnd) <- SStreams.sslToStreams ssl
  109. let cleanup = (do Streams.write Nothing writeEnd
  110. SSL.shutdown ssl $! SSL.Unidirectional)
  111. `finally` Socket.close sock
  112. return $! ( sendFileFunc ssl
  113. , localHost
  114. , localPort
  115. , remoteHost
  116. , remotePort
  117. , readEnd
  118. , writeEnd
  119. , cleanup
  120. )
  121. ------------------------------------------------------------------------------
  122. sendFileFunc :: SSL -> SendFileHandler
  123. sendFileFunc ssl buffer builder fPath offset nbytes = do
  124. Streams.unsafeWithFileAsInputStartingAt (fromIntegral offset) fPath $ \fileInput0 -> do
  125. fileInput <- Streams.takeBytes (fromIntegral nbytes) fileInput0 >>=
  126. Streams.map byteString
  127. input <- Streams.fromList [builder] >>=
  128. flip Streams.appendInputStream fileInput
  129. output <- Streams.makeOutputStream sendChunk >>=
  130. Streams.unsafeBuilderStream (return buffer)
  131. Streams.supply input output
  132. Streams.write Nothing output
  133. where
  134. sendChunk (Just s) = SSL.write ssl s
  135. sendChunk Nothing = return $! ()
  136. #endif