{-# LANGUAGE UndecidableInstances, FieldSelectors #-}

module Conduit.Identity.Password
  ( HashedPassword(..)
  , UnsafePassword(..)
  , PasswordGen(..)
  , testPassword
  ) where

import Conduit.Utils ((.-))
import Conduit.Validation (Validation(..), NotBlank)
import Crypto.Error (CryptoFailable(..))
import Crypto.KDF.Argon2 (Options(..), Variant(..), defaultOptions)
import Crypto.KDF.Argon2 qualified as Argon
import Crypto.Random (MonadRandom (getRandomBytes))
import Data.Aeson (FromJSON)
import Data.ByteArray (Bytes, constEq, convert)
import Data.ByteString.Base64 (decodeBase64, encodeBase64)
import Data.Text (splitOn)
import Relude.Unsafe as Unsafe ((!!))

-- | A password after it has gone through Argon2 hashing.
--
-- The wrapped 'Text' is the complete encoded hash string produced by this
-- module: an Argon2 parameter prefix followed by the base64 salt and the
-- base64 digest, separated by @$@. Equality is plain 'Text' equality.
newtype HashedPassword = HashedPassword
  { getHashed :: Text
  }
  deriving newtype Eq

-- | A plaintext password exactly as supplied by the user.
--
-- It can be decoded straight from JSON (the 'FromJSON' instance is the one
-- for 'Text'). It is intended to be hashed with 'hashPassword' rather than
-- kept around. No 'Show' instance is derived, which makes it harder to leak
-- through logging by accident.
newtype UnsafePassword = UnsafePassword
  { getUnsafe :: Text
  }
  deriving newtype FromJSON

-- | An 'UnsafePassword' is 'NotBlank' exactly when its underlying 'Text' is.
-- Both the check and the error message are delegated to the 'Text' instance.
instance Validation NotBlank UnsafePassword where
  validate :: UnsafePassword -> Bool
  validate (UnsafePassword raw) = validate @NotBlank raw

  errMsg :: String
  errMsg = errMsg @NotBlank @Text

-- | A monad that can turn an 'UnsafePassword' into a 'HashedPassword'.
--
-- Hashing lives in a monad because it needs a fresh random salt. In this
-- module, every 'MonadIO' monad gets an instance.
class (Monad m) => PasswordGen m where
  -- | Hash a plaintext password. The 'MonadIO' instance in this module
  -- generates a new random salt on every call, so hashing the same password
  -- twice yields different 'HashedPassword's.
  hashPassword :: UnsafePassword -> m HashedPassword

-- | Any 'MonadIO' monad can hash passwords. It draws a fresh salt with
-- 'newSalt' and then hashes the password with it via 'hashPasswordWithSalt'.
-- (This catch-all instance head is why the module enables
-- UndecidableInstances.)
instance (Monad m, MonadIO m) => PasswordGen m where
  hashPassword :: UnsafePassword -> m HashedPassword
  hashPassword plain = hashPasswordWithSalt plain <$> newSalt

-- | Argon2 settings (Cryptonite) used for every password hash in this module.
-- They are intended as a reasonable trade-off between performance and
-- security.
--
-- These values are never read back from a stored hash. 'testPassword'
-- re-hashes with whatever 'argonOptions' currently is and compares the full
-- encoded string. Changing any field here therefore makes every previously
-- stored hash fail verification.
argonOptions :: Options
argonOptions =
  defaultOptions
    { variant = Argon2id
    , iterations = 2 -- time cost (passes over memory)
    , memory = 65536 -- memory cost; presumably KiB (64 MiB) as in the Argon2 spec - TODO confirm
    , parallelism = 2 -- degree of parallelism (lanes)
    }

-- | Parameter prefix of every encoded hash, built from 'argonOptions'.
--
-- It has the shape @$argon2id$v=13$m=<memory>,t=<iterations>,p=<parallelism>$@.
-- The trailing @$@ lets the base64 salt be appended directly.
--
-- NOTE(review): the algorithm tag and version are hard-coded rather than
-- derived from 'argonOptions'. The PHC string format writes the Argon2 1.3
-- version in decimal (@v=19@), not @v=13@. Do not "fix" this in place:
-- 'testPassword' compares whole encoded strings, so any change would break
-- every hash already stored.
hashStrParams :: Text
hashStrParams =
  "$argon2id$v=13$m=" <> show memCost
    <> ",t=" <> show timeCost
    <> ",p=" <> show lanes
    <> "$"
  where
    Options {memory = memCost, iterations = timeCost, parallelism = lanes} = argonOptions

-- | Produce a fresh random 16-byte salt.
--
-- The bytes come from Cryptonite's 'getRandomBytes' run in 'IO' and are
-- lifted into any 'MonadIO' monad.
newSalt :: (MonadIO m) => m ByteString
newSalt = liftIO (getRandomBytes saltLength)
  where
    saltLength :: Int
    saltLength = 16

-- | Recover the raw salt bytes from an encoded 'HashedPassword'.
--
-- Hashes built by this module look like
-- @$argon2id$v=13$m=..,t=..,p=..$<base64 salt>$<base64 digest>@. Splitting on
-- @$@ therefore yields an empty leading field, and the salt sits at index 4.
--
-- Returns 'Nothing' in two cases:
--
-- * the string has too few @$@-separated fields, or
-- * the salt field is not valid base64.
--
-- The previous version indexed with the partial @!!@, so corrupted or foreign
-- data in storage crashed the caller instead of simply failing to match.
extractSalt :: HashedPassword -> Maybe ByteString
extractSalt (HashedPassword encoded) =
  case splitOn "$" encoded of
    (_ : _ : _ : _ : encodedSalt : _) ->
      rightToMaybe (decodeBase64 (encodeUtf8 encodedSalt))
    _ -> Nothing

-- | Encode 'Text' as UTF-8 and copy the result into Cryptonite's 'Bytes'
-- container. This is the form in which the password is fed to 'Argon.hash'.
text2bytes :: Text -> Bytes
text2bytes txt = convert utf8
  where
    utf8 :: ByteString
    utf8 = encodeUtf8 txt

-- | Hash a plaintext password with a caller-supplied salt.
--
-- This uses Argon2 with 'argonOptions' and a 32-byte output, then encodes the
-- result with 'mkHashedPassword'. The function is pure, so the same password
-- and salt always give the same 'HashedPassword'; 'testPassword' relies on
-- this when re-checking a password.
--
-- A 'CryptoFailed' result becomes a runtime 'error'. Any 'CryptoError' here
-- is treated as a deeper, non-recoverable problem, so the code fails fast.
hashPasswordWithSalt :: UnsafePassword -> ByteString -> HashedPassword
hashPasswordWithSalt (UnsafePassword password) salt =
  mkHashedPassword (convert digest) salt
  where
    saltBytes :: Bytes
    saltBytes = convert salt

    digest :: Bytes
    digest =
      case Argon.hash argonOptions (text2bytes password) saltBytes 32 of
        CryptoPassed out -> out
        CryptoFailed err -> error (show err)

-- | Assemble the encoded hash string from a raw digest and salt.
--
-- The result is 'hashStrParams', then the base64 salt, a @$@ separator, and
-- the base64 digest. 'extractSalt' depends on this exact layout.
mkHashedPassword :: ByteString -> ByteString -> HashedPassword
mkHashedPassword digest salt =
  HashedPassword (hashStrParams <> encodeBase64 salt <> "$" <> encodeBase64 digest)

-- | Check whether a plaintext password matches a stored 'HashedPassword'.
--
-- The steps are:
--
-- 1. Read the salt back out of the stored hash.
-- 2. Re-hash the candidate with that salt, using the current 'argonOptions'.
-- 3. Compare the two encoded strings.
--
-- Returns 'False' when no salt can be extracted.
--
-- The comparison uses 'constEq', which does not stop at the first differing
-- byte. Plain '==' on 'Text' would short-circuit and leak, through response
-- timing, how long a prefix of the stored hash was matched. 'constEq' returns
-- 'False' for inputs of different lengths, so the result equals the previous
-- '==' check.
testPassword :: UnsafePassword -> HashedPassword -> Bool
testPassword password hashed = maybe False matchesWith (extractSalt hashed)
  where
    matchesWith :: ByteString -> Bool
    matchesWith salt =
      constEq
        (utf8 (getHashed (hashPasswordWithSalt password salt)))
        (utf8 (getHashed hashed))

    utf8 :: Text -> ByteString
    utf8 = encodeUtf8