{-# LANGUAGE BangPatterns, ScopedTypeVariables #-}
{- |
Borrowed from snap-server. Check there periodically for updates.
-}
module Happstack.Server.Internal.TimeoutSocket where

import           Control.Applicative           (pure)
import           Control.Concurrent            (threadWaitWrite)
import           Control.Exception             as E (catch, throw)
import           Control.Monad                 (liftM, when)
import qualified Data.ByteString.Char8         as B
import qualified Data.ByteString.Lazy.Char8    as L
import qualified Data.ByteString.Lazy.Internal as L
import qualified Data.ByteString               as S
import           Network.Socket                (close)
import qualified Network.Socket.ByteString     as N
import qualified Happstack.Server.Internal.TimeoutManager as TM
import           Happstack.Server.Internal.TimeoutIO (TimeoutIO(..))
import           Network.Socket (Socket, ShutdownCmd(..), shutdown)
import           Network.Socket.SendFile (Iter(..), ByteCount, Offset, sendFileIterWith')
import           Network.Socket.ByteString (sendAll)
import           System.IO.Error (isDoesNotExistError, ioeGetErrorType)
import           System.IO.Unsafe (unsafeInterleaveIO)
import           GHC.IO.Exception (IOErrorType(InvalidArgument))

-- | Write a lazy 'L.ByteString' to the socket one strict chunk at a time.
--
-- Each chunk is sent with 'sendAll', and 'TM.tickle' is called on the
-- timeout handle after every chunk before moving on to the next one.
-- An empty lazy bytestring results in no sends and no tickles.
--
-- NOTE(review): 'TM.tickle' presumably keeps the connection from timing out
-- while data is still flowing; confirm against the TimeoutManager module.
sPutLazyTickle :: TM.Handle -> Socket -> L.ByteString -> IO ()
sPutLazyTickle thandle sock cs =
    L.foldrChunks sendChunk (return ()) cs
  where
    -- Send one chunk, tickle the handle, then continue with the remaining
    -- chunks (the already-folded tail of the stream).
    sendChunk c rest = sendAll sock c >> TM.tickle thandle >> rest
{-# INLINE sPutLazyTickle #-}

-- | Write a strict 'B.ByteString' to the socket with 'sendAll', then call
-- 'TM.tickle' on the timeout handle once the whole string has been sent.
sPutTickle :: TM.Handle -> Socket -> B.ByteString -> IO ()
sPutTickle thandle sock cs =
    sendAll sock cs >> TM.tickle thandle
{-# INLINE sPutTickle #-}

-- | Receive up to 65536 bytes from the socket and call 'TM.tickle' on the
-- timeout handle afterwards.
--
-- Returns 'Nothing' when 'N.recv' yields an empty string and @Just bytes@
-- otherwise.
--
-- NOTE(review): an empty result from 'N.recv' is presumably the
-- end-of-stream signal (peer closed its end); confirm against the network
-- package documentation.
sGet :: TM.Handle
     -> Socket
     -> IO (Maybe B.ByteString)
sGet handle socket =
  do s <- N.recv socket 65536
     TM.tickle handle
     pure (if S.null s then Nothing else Just s)

-- | Lazily read everything from the socket as a lazy 'L.ByteString'.
--
-- The result is built with 'unsafeInterleaveIO', so each chunk (up to 65536
-- bytes) is only received when that part of the result is demanded.
-- 'TM.tickle' is called on the timeout handle after every receive. When
-- 'N.recv' yields an empty string the receive side of the socket is shut
-- down and the lazy bytestring is terminated.
sGetContents :: TM.Handle
             -> Socket         -- ^ Connected socket
             -> IO L.ByteString  -- ^ Data received
sGetContents handle sock = loop
  where
    loop :: IO L.ByteString
    loop = unsafeInterleaveIO $ do
      s <- N.recv sock 65536
      TM.tickle handle
      if S.null s
        then do
          -- 'InvalidArgument' is GHCs code for eNOTCONN (among other
          -- things). Sometimes the other end of socket is closed first
          -- and this end is already disconnected before we do
          -- 'shutdown'. Ignore this exception.
          shutdown sock ShutdownReceive `E.catch` ignoreNotConnected
          return L.Empty
        else L.Chunk s `liftM` loop

    -- Swallow "does not exist" and 'InvalidArgument' I/O errors raised by
    -- 'shutdown'; any other 'IOError' is re-thrown unchanged.
    ignoreNotConnected :: IOError -> IO ()
    ignoreNotConnected e =
      when (not (isDoesNotExistError e || ioeGetErrorType e == InvalidArgument))
           (throw e)


-- | Send part of a file over the socket via 'sendFileIterWith'', driving the
-- transfer with 'iterTickle' so that the timeout handle is tickled after
-- every step of the iteration.
--
-- The @offset@ and @count@ arguments are forwarded unchanged to
-- 'sendFileIterWith''.
sendFileTickle :: TM.Handle -> Socket -> FilePath -> Offset -> ByteCount -> IO ()
sendFileTickle thandle outs fp offset count =
    -- NOTE(review): 65536 sits in what appears to be the block-size
    -- position of 'sendFileIterWith''; confirm against the sendfile docs.
    sendFileIterWith' (iterTickle thandle) outs fp 65536 offset count

-- | Run a sendfile 'Iter' computation to completion, calling 'TM.tickle' on
-- the timeout handle after every step.
--
-- * 'Done': stop.
-- * 'WouldBlock': wait with 'threadWaitWrite' on the reported file
--   descriptor, then run the continuation.
-- * 'Sent': run the continuation immediately.
iterTickle :: TM.Handle -> IO Iter -> IO ()
iterTickle thandle = go
  where
    go :: IO Iter -> IO ()
    go iter = do
      r <- iter
      -- Tickle after every step, including the final 'Done'.
      TM.tickle thandle
      case r of
        Done _               -> return ()
        WouldBlock _ fd cont -> threadWaitWrite fd >> go cont
        Sent _ cont          -> go cont

-- | Build a 'TimeoutIO' record for a plain (non-secure) socket.
--
-- Every I/O field is the corresponding tickling helper from this module
-- partially applied to the given timeout handle and socket; 'toShutdown'
-- closes the socket and 'toSecure' is 'False'.
timeoutSocketIO :: TM.Handle -> Socket -> TimeoutIO
timeoutSocketIO handle socket =
    TimeoutIO { toHandle      = handle
              , toShutdown    = close socket
              , toPutLazy     = sPutLazyTickle handle socket
              , toGet         = sGet           handle socket
              , toPut         = sPutTickle     handle socket
              , toGetContents = sGetContents   handle socket
              , toSendFile    = sendFileTickle handle socket
              , toSecure      = False
              }