{-# LANGUAGE AllowAmbiguousTypes, UndecidableInstances, TemplateHaskell #-}

module Conduit.DB.Core where

import Conduit.App.Has (Has, grab)
import Conduit.Errors (FeatureError(..))
import Conduit.Utils ((.-))
import Data.List (stripPrefix)
import Database.Esqueleto.Experimental (ConnectionPool, Key, SqlPersistT, fromSqlKey, runSqlPool, toSqlKey)
import Database.PostgreSQL.Simple (ExecStatus(..), SqlError(..))
import Language.Haskell.TH
import UnliftIO (MonadUnliftIO, catch)

-- | Wrapper around the 'ConnectionPool' that manages the DB connections.
--   Wrapping it in its own type lets it be looked up from the application
--   environment by type (via 'Has').
newtype DBPool = DBPool
  { unPool :: ConnectionPool
  }

-- | A monad that can execute an Esqueleto SQL query or statement.
--   Database failures are reported in the result as a 'DBError'
--   instead of being left to propagate.
class Monad m => MonadDB m where
  -- | Run the given SQL action, yielding either a 'DBError' or its result.
  runDB :: SqlPersistT m a -> m (Either DBError a)

-- | Runs SQL actions against the 'DBPool' found in the environment.
--   Any 'SqlError' thrown while the action executes is converted into a
--   'DBError' by 'catchSqlError'. Fetching the pool itself happens outside
--   that handler.
instance (Monad m, MonadUnliftIO m, Has DBPool c m) => MonadDB m where
  runDB :: SqlPersistT m a -> m (Either DBError a)
  runDB query = do
    pool <- (.unPool) <$> grab @DBPool
    catchSqlError (runSqlPool query pool)

-- | An abstraction for converting between Esqueleto entity 'Key's and
--   Conduit's own ID datatypes. The functional dependencies make each table
--   type determine its ID type, and vice versa.
--   'deriveSqlKey' can automatically create instances for this class.
--
-- > newtype TableID = TableID { unID :: Int64 }
-- > <define some persist table Table>
-- > $(deriveSqlKey ''Table ''TableID)
class SqlKey t id | t -> id, id -> t where
  -- | Convert an Esqueleto key into the matching Conduit ID.
  sqlKey2ID :: Key t -> id
  -- | Convert a Conduit ID back into an Esqueleto key.
  id2sqlKey :: id -> Key t

-- | Template Haskell generator for 'SqlKey' instances.
--
--   Given the persistent table type and an ID newtype over 'Int64' (the type
--   'fromSqlKey' produces), it emits an instance equivalent to:
--
-- > instance SqlKey Table TableID where
-- >   sqlKey2ID = TableID . fromSqlKey
-- >   id2sqlKey (TableID x) = toSqlKey x
--
--   Compilation fails if the ID type is not a single-constructor newtype
--   (see 'getConName').
--
--   See 'SqlKey' for usage.
deriveSqlKey :: Name -> Name -> Q [Dec]
deriveSqlKey tableName keyName = do
  conName <- getConName keyName
  -- Fresh, hygienic name for the pattern variable. Unlike @mkName "id'"@, it
  -- cannot be captured by, or clash with, anything at the splice site.
  rawId <- newName "rawId"

  [d|
    instance SqlKey $(conT tableName) $(conT keyName) where
      sqlKey2ID = $(conE conName) . fromSqlKey
      id2sqlKey $(conP conName [varP rawId]) = toSqlKey $(varE rawId)
    |]

-- | Relevant 'SqlError's mapped into types that are easier to process.
data DBError
  = SomeDBError Text        -- ^ Catch-all for irrelevant/unknown SqlErrors
  | UniquenessError Text    -- ^ Unique constraint violation; holds the name of the violated column(s)
  | AuthorizationError Text -- ^ See 'authorizationSqlError'
  | NotFoundError           -- ^ See 'resourceNotFoundSqlError'
  deriving (Eq, Read, Show)

-- | Maps both sides of a potentially errored SQL query.
--   The 'DBError' is turned into the caller's feature error via
--   'handleDBError', and the success value is transformed with @f@.
--
-- > newtype Name = Name Text
-- >
-- > mkNames :: [Value Text] -> [Name]
-- > mkNames = map (\(Value name) -> Name name)
-- >
-- > result :: (MonadDB m, FeatureError e) => m (Either e [Name])
-- > result = mapDBResult mkNames <$> runDB do
-- >   select $ do
-- >     u <- from table @User
-- >     pure u.name
mapDBResult :: (FeatureError e) => (a -> b) -> Either DBError a -> Either e b
mapDBResult f = either (Left . handleDBError) (Right . f)

-- | Maps both sides of a potentially errored, potentially not-found SQL query.
--
--   * A 'DBError' becomes a feature error via 'handleDBError'.
--   * A missing row ('Nothing') becomes the supplied @notFound@ error.
--   * A found row is transformed with @f@.
mapMaybeDBResult :: (FeatureError e) => e -> (a -> b) -> Either DBError (Maybe a) -> Either e b
mapMaybeDBResult notFound f dbResult =
  case dbResult of
    Left dbErr       -> Left (handleDBError dbErr)
    Right Nothing    -> Left notFound
    Right (Just val) -> Right (f val)

-- | Maps only the error side of a potentially errored SQL query/stmt, turning
--   a 'DBError' into the caller's feature error via 'handleDBError'.
--   A successful result passes through untouched.
mapDBError :: (FeatureError e) => Either DBError a -> Either e a
mapDBError = either (Left . handleDBError) Right

-- | Maps the error of a potentially errored SQL query/stmt, and asserts that
--   the resulting count is greater than zero. If it is not, the query fails
--   with @err@.
--   Intended for use with something like Esqueleto's @insertCount@ or @deleteCount@.
expectDBNonZero :: (FeatureError e, Num cnt, Ord cnt) => e -> Either DBError cnt -> Either e ()
expectDBNonZero err dbResult = do
  count <- first handleDBError dbResult
  -- `<= 0` rather than `== 0`: the contract is "count > 0", and the 'Ord'
  -- constraint in the signature lets nonsensical negative counts be rejected too.
  when (count <= 0) $ Left err

-- | A custom 'SqlError' with SQLSTATE @45401@.
--   'mapSqlError' maps this state to 'AuthorizationError', decoding the
--   detail field, which here holds the 'show'n @reason@.
authorizationSqlError :: (Show e) => e -> SqlError
authorizationSqlError reason =
  defaultSqlErr
    { sqlErrorDetail = show reason
    , sqlErrorMsg    = "Authorization error"
    , sqlState       = "45401"
    }

-- | A custom 'SqlError' with SQLSTATE @45404@, which 'mapSqlError' maps to
--   'NotFoundError'.
resourceNotFoundSqlError :: SqlError
resourceNotFoundSqlError =
  defaultSqlErr
    { sqlErrorMsg = "Resource not found"
    , sqlState    = "45404"
    }

-- | (Internal) Runs @stmt@ and wraps its result in 'Right'.
--   Any 'SqlError' it throws is caught and converted via 'mapSqlError' into
--   a 'Left'. Exceptions of any other type are not caught.
catchSqlError :: (MonadUnliftIO m) => m a -> m (Either DBError a)
catchSqlError stmt = catch @_ @SqlError (fmap Right stmt) onSqlError
  where
    onSqlError sqlErr = pure (Left (mapSqlError sqlErr))

-- | (Internal) Classifies a 'SqlError' by its SQLSTATE code.
--
--   * @23505@ (unique violation) becomes 'UniquenessError'.
--   * @45401@ and @45404@ are the custom codes of 'authorizationSqlError'
--     and 'resourceNotFoundSqlError'.
--   * Everything else becomes 'SomeDBError' holding the 'show'n error.
mapSqlError :: SqlError -> DBError
mapSqlError err =
  case sqlState err of
    "23505" -> UniquenessError (extractUniquenessViolation err)
    "45401" -> AuthorizationError (decodeUtf8 (sqlErrorDetail err))
    "45404" -> NotFoundError
    _       -> SomeDBError (show err)

-- Example of the kind of error this parses:
-- SqlError {sqlState = "23505", sqlExecStatus = FatalError, sqlErrorMsg = "duplicate key value violates unique constraint \"unique_username\"", sqlErrorDetail = "Key (username)=(username) already exists.", sqlErrorHint = ""}

-- | (Internal) Extracts the name of the column(s) whose uniqueness was violated
--   from the error's detail text (e.g. @username@ in the example above).
--   If the detail is missing or not in the expected @Key (...)=(...)@ form,
--   this falls back to the error's message text. It does not throw.
extractUniquenessViolation :: SqlError -> Text
extractUniquenessViolation err =
  maybe (decodeUtf8 (sqlErrorMsg err)) toText
    . extractKeyField
    . decodeUtf8
    $ sqlErrorDetail err

-- | (Internal) Pulls the key column list out of a Postgres key-detail string.
--   For example, @"Key (username)=(bob) already exists."@ yields
--   @Just "username"@. The result is everything after @"Key ("@ up to, but
--   excluding, the first @')'@. Input not starting with @"Key ("@ yields 'Nothing'.
extractKeyField :: String -> Maybe String
extractKeyField str = takeWhile (/= ')') <$> stripPrefix "Key (" str

-- | (Internal) A 'SqlError' template whose rarely relevant fields are pre-filled.
--   'sqlState' and 'sqlErrorMsg' are deliberately 'error' thunks. Every use
--   site must override both, since forcing either one otherwise throws.
defaultSqlErr :: SqlError
defaultSqlErr =
  SqlError
    { sqlExecStatus  = FatalError
    , sqlErrorDetail = ""
    , sqlErrorHint   = ""
    , sqlState       = error "fill out sqlState"
    , sqlErrorMsg    = error "fill out sqlErrorMsg"
    }

-- | (Internal) Gets the data constructor name of a newtype.
--   Both record syntax (@newtype T = T { unT :: X }@) and plain syntax
--   (@newtype T = T X@) are accepted. Anything else fails the splice with a
--   descriptive message.
getConName :: Name -> Q Name
getConName typeName = do
  info <- reify typeName
  -- Handle non-newtype names explicitly rather than via an irrefutable
  -- `TyConI` bind, which would only report an opaque pattern-match failure.
  case info of
    TyConI (NewtypeD _ _ _ _ con _) ->
      case con of
        RecC name _    -> pure name
        NormalC name _ -> pure name
        _              -> fail "Newtype constructor not in expected format"
    _ -> fail "Expected a newtype"