{-# LANGUAGE UndecidableInstances, FieldSelectors #-}

module Conduit.Identity.Auth where

import Conduit.App.Has (Has, grab)
import Conduit.Features.Account.Types (UserID)
import Conduit.Identity.JWT (JWTInfo(..), jwtExpTime, mkClaims)
import Conduit.Utils ((.-))
import Data.Map.Strict as M (fromList)
import Data.Text (splitOn)
import Data.Time.Clock.POSIX (POSIXTime, getPOSIXTime)
import Network.HTTP.Types (status401)
import Relude.Extra (dup)
import Web.JWT (JWTClaimsSet(..), claims, decodeAndVerifySignature, encodeSigned, numericDate, stringOrURIToText)
import Web.Scotty.Trans (ActionT, header, json, status)

-- | The authenticated caller handed to any endpoint wrapped in 'withAuth' or
-- 'maybeWithAuth'.
data AuthedUser = AuthedUser
  { authedToken  :: !Text
    -- ^ The bare token taken from the @Authorization@ header (scheme removed).
  , authedUserID :: !UserID
    -- ^ The user ID parsed from the token's @sub@ claim.
  }
  deriving (Eq, Show)

-- | An endpoint which requires user authentication.
--
-- Delegates to 'maybeWithAuth'. When a user is resolved, the handler runs
-- with it. Otherwise the handler is skipped and the response is a 401 with
-- 'authErrRes' as its JSON body.
--
-- > endpoint = get "/" $ withAuth \(user :: AuthedUser) -> do
-- >   ...
--
-- NOTE(review): the body does not use the 'MonadReader' constraint;
-- 'maybeWithAuth' does not demand it. Kept for signature stability.
withAuth :: (MonadIO m, MonadReader c m, Has JWTInfo c m) => (AuthedUser -> ActionT m ()) -> ActionT m ()
withAuth handler = maybeWithAuth (maybe rejectUnauthenticated handler)
  where
    -- No valid credentials: set 401 first, then write the JSON error body.
    rejectUnauthenticated = status status401 >> json authErrRes

-- | An endpoint which requests, but does not require, user authentication.
--
-- The handler always runs. It receives @Just user@ when the request's
-- @Authorization@ header carries a @Token@ whose signature verifies, whose
-- expiry lies after the current time, and whose subject parses as a user ID.
-- It receives 'Nothing' in every other case (header missing, malformed, or
-- rejected). This function never produces an error response itself.
--
-- > endpoint = get "/" $ maybeWithAuth \(user :: Maybe AuthedUser) -> do
-- >   ...
maybeWithAuth :: (MonadIO m, Has JWTInfo c m) => (Maybe AuthedUser -> ActionT m ()) -> ActionT m ()
maybeWithAuth handler = do
  authHeader <- header "Authorization"
  jwtInfo <- lift $ grab @JWTInfo
  currTime <- liftIO getPOSIXTime
  -- Every failure along the way collapses into 'Nothing' for the handler.
  handler $ authHeader >>= tryMakeAuthedUser jwtInfo currTime

-- | Monads that can mint a signed authentication token (a JWT) for a user.
class Monad m => AuthTokenGen m where
  -- | Produce a token identifying the given user.
  mkAuthToken :: UserID -> m Text

-- | Any monad that can perform IO and supply a 'JWTInfo' can mint tokens.
-- It fetches the JWT configuration, samples the current POSIX time, and
-- hands both to 'makeAuthTokenPure'.
--
-- NOTE: the instance head is a bare type variable, so it matches every @m@.
-- The context mentions @c@, which does not appear in the head; this is
-- presumably why the module enables UndecidableInstances.
instance (Monad m, MonadIO m, Has JWTInfo c m) => AuthTokenGen m where
  mkAuthToken :: UserID -> m Text
  mkAuthToken userID = signFor <$> grab @JWTInfo <*> liftIO getPOSIXTime
    where
      -- Effects run left to right: configuration first, then the clock.
      signFor info now = makeAuthTokenPure info now userID

-- | Build and sign a token for @userID@ without performing any effects.
--
-- The claims come from 'mkClaims', given the supplied time, the configured
-- 'jwtExpTime', and the user ID. They are signed with the configured encode
-- signer under an empty ('mempty') JOSE header. The caller supplies the time,
-- which keeps this function pure.
makeAuthTokenPure :: JWTInfo -> POSIXTime -> UserID -> Text
makeAuthTokenPure JWTInfo {jwtEncodeSigner = signer, jwtExpTime = expSeconds} currTime userID =
  encodeSigned signer mempty tokenClaims
  where
    tokenClaims = mkClaims currTime expSeconds userID

-- | Resolve an @Authorization@ header value into an 'AuthedUser'.
--
-- Yields 'Nothing' in any of these cases:
--
-- * the value is not a @Token@-scheme credential;
-- * the token fails verification or its expiry is not after @time@;
-- * the token's subject cannot be read as a 'UserID'.
--
-- On success, the user carries the bare token, without the scheme prefix.
tryMakeAuthedUser :: (ToText a) => JWTInfo -> POSIXTime -> a -> Maybe AuthedUser
tryMakeAuthedUser jwtInfo time authHeader = do
  token <- extractToken (toText authHeader)
  userID <- tryGetSubjectFromJWT jwtInfo time token
  pure (AuthedUser token userID)

-- | Verify @token@ and return the user ID stored in its @sub@ claim.
--
-- Returns 'Nothing' in any of these cases:
--
-- * the signature does not verify;
-- * the @exp@ claim is missing or not later than @time@ (as converted by
--   'numericDate');
-- * there is no @sub@ claim;
-- * the subject text does not parse via 'readMaybe'.
tryGetSubjectFromJWT :: JWTInfo -> POSIXTime -> Text -> Maybe UserID
tryGetSubjectFromJWT jwtInfo time token = do
  clms <- tryGetClaims jwtInfo token
  -- Both sides are 'Maybe' values. 'Nothing' is the minimum, so a token with
  -- no @exp@ never compares greater and is rejected.
  guard (clms.exp > numericDate time)
  subject <- sub clms
  readMaybe (toString (stringOrURIToText subject))

-- | Decode @token@ and check its signature with the configured verify signer.
-- Returns the claims set only when verification succeeds.
tryGetClaims :: JWTInfo -> Text -> Maybe JWTClaimsSet
tryGetClaims JWTInfo {jwtVerifySigner = verifier} token =
  claims <$> decodeAndVerifySignature verifier token

-- | Extract the credentials from an @Authorization@ header value of the form
-- @Token <credentials>@.
--
-- Returns 'Nothing' unless the value consists of exactly the scheme @Token@
-- followed by one credential. Runs of spaces between or around the two parts
-- are tolerated, because RFC 7235 allows one or more SP after the scheme.
-- The scheme itself is matched case-sensitively.
--
-- > extractToken "Token abc"  == Just "abc"
-- > extractToken "Token  abc" == Just "abc"
-- > extractToken "Bearer abc" == Nothing
extractToken :: Text -> Maybe Text
extractToken str = case filter (/= "") (splitOn " " str) of
  -- Dropping the empty fields left by repeated or leading/trailing spaces
  -- keeps an otherwise well-formed header from being rejected.
  ["Token", token] -> Just token
  _ -> Nothing

-- | JSON body that 'withAuth' sends with its 401 response when no valid
-- credentials were supplied.
authErrRes :: Map Text Text
authErrRes = M.fromList [(key, msg)]
  where
    key, msg :: Text
    key = "message"
    msg = "missing authorization credentials"