Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion postgrest.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,6 @@ library
, timeit >= 2.0 && < 2.1
, unordered-containers >= 0.2.8 && < 0.3
, unix-compat >= 0.5.4 && < 0.8
, vault >= 0.3.1.5 && < 0.4
, vector >= 0.11 && < 0.14
, wai >= 3.2.1 && < 3.3
, wai-cors >= 0.2.5 && < 0.3
Expand Down
30 changes: 16 additions & 14 deletions src/PostgREST/App.hs
Original file line number Diff line number Diff line change
Expand Up @@ -99,21 +99,16 @@ postgrest :: LogLevel -> AppState.AppState -> IO () -> Wai.Application
postgrest logLevel appState connWorker =
traceHeaderMiddleware appState .
Cors.middleware appState .
Auth.middleware appState .
Logger.middleware logLevel Auth.getRole $
-- fromJust can be used, because the auth middleware will **always** add
-- some AuthResult to the vault.
\req respond -> case fromJust $ Auth.getResult req of
Left err -> respond $ Error.errorResponseFor err
Right authResult -> do
Logger.middleware logLevel $
\req respond -> do
appConf <- AppState.getConfig appState -- the config must be read again because it can reload
maybeSchemaCache <- AppState.getSchemaCache appState
pgVer <- AppState.getPgVersion appState

let
eitherResponse :: IO (Either Error Wai.Response)
eitherResponse =
runExceptT $ postgrestResponse appState appConf maybeSchemaCache pgVer authResult req
runExceptT $ postgrestResponse appState appConf maybeSchemaCache pgVer req

response <- either Error.errorResponseFor identity <$> eitherResponse
-- Launch the connWorker when the connection is down. The postgrest
Expand All @@ -130,10 +125,9 @@ postgrestResponse
-> AppConfig
-> Maybe SchemaCache
-> PgVersion
-> AuthResult
-> Wai.Request
-> Handler IO Wai.Response
postgrestResponse appState conf@AppConfig{..} maybeSchemaCache pgVer authResult@AuthResult{..} req = do
postgrestResponse appState conf@AppConfig{..} maybeSchemaCache pgVer req = do
sCache <-
case maybeSchemaCache of
Just sCache ->
Expand All @@ -143,13 +137,20 @@ postgrestResponse appState conf@AppConfig{..} maybeSchemaCache pgVer authResult@

body <- lift $ Wai.strictRequestBody req

let jwtTime = if configServerTimingEnabled then Auth.getJwtDur req else Nothing
timezones = dbTimezones sCache
prefs = ApiRequest.userPreferences conf req timezones
-- API-REQUEST/PARSE STAGE
let prefs = ApiRequest.userPreferences conf req (dbTimezones sCache)

(parseTime, apiReq@ApiRequest{..}) <- withTiming $ liftEither . mapLeft Error.ApiRequestError $ ApiRequest.userApiRequest conf prefs req body
(planTime, plan) <- withTiming $ liftEither $ Plan.actionPlan iAction conf apiReq sCache

-- JWT/AUTH STAGE
(jwtTime, authResult@AuthResult{..}) <- withTiming $ do
eitherAuthResult <- liftIO $ Auth.getAuthResult appState apiReq
liftEither eitherAuthResult

-- PLAN STAGE
(planTime, plan) <- withTiming $ liftEither $ Plan.actionPlan iAction conf apiReq sCache

-- QUERY/TRANSACTION STAGE
let query = Query.query conf authResult apiReq plan sCache pgVer
logSQL = lift . AppState.getObserver appState . DBQuery (Query.getSQLQuery query)

Expand All @@ -162,6 +163,7 @@ postgrestResponse appState conf@AppConfig{..} maybeSchemaCache pgVer authResult@
when (configLogQuery /= LogQueryDisabled) $ whenLeft eitherResp $ logSQL . Error.status
liftEither eitherResp >>= liftEither

-- RESPONSE STAGE
(respTime, resp) <- withTiming $ do
let response = Response.actionResponse queryResult apiReq (T.decodeUtf8 prettyVersion, docsVersion) conf sCache iSchema iNegotiatedByProfile
when (configLogQuery /= LogQueryDisabled) $ logSQL $ either Error.status Response.pgrstStatus response
Expand Down
82 changes: 22 additions & 60 deletions src/PostgREST/Auth.hs
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,8 @@ very simple authentication system inside the PostgreSQL database.
-}
{-# LANGUAGE RecordWildCards #-}
module PostgREST.Auth
( getResult
, getJwtDur
, getRole
, middleware
) where
( getAuthResult )
where

import qualified Data.Aeson as JSON
import qualified Data.Aeson.Key as K
Expand All @@ -25,24 +22,22 @@ import qualified Data.Aeson.Types as JSON
import qualified Data.ByteString as BS
import qualified Data.ByteString.Internal as BS
import qualified Data.ByteString.Lazy.Char8 as LBS
import qualified Data.CaseInsensitive as CI
import qualified Data.Scientific as Sci
import qualified Data.Text as T
import qualified Data.Vault.Lazy as Vault
import qualified Data.Vector as V
import qualified Jose.Jwk as JWT
import qualified Jose.Jwt as JWT
import qualified Network.HTTP.Types.Header as HTTP
import qualified Network.Wai as Wai
import qualified Network.Wai.Middleware.HttpAuth as Wai

import Control.Monad.Except (liftEither)
import Data.Either.Combinators (mapLeft)
import Data.List (lookup)
import Data.Time.Clock (UTCTime, nominalDiffTimeToSeconds)
import Data.Time.Clock.POSIX (utcTimeToPOSIXSeconds)
import System.IO.Unsafe (unsafePerformIO)
import System.TimeIt (timeItT)

import PostgREST.ApiRequest (ApiRequest (..))
import PostgREST.AppState (AppState, getConfig, getJwtCacheState,
getTime)
import PostgREST.Auth.JwtCache (lookupJwtCache)
Expand Down Expand Up @@ -131,11 +126,12 @@ parseClaims AppConfig{..} jclaims@(JSON.Object mclaims) = do
walkJSPath x [] = x
walkJSPath (Just (JSON.Object o)) (JSPKey key:rest) = walkJSPath (KM.lookup (K.fromText key) o) rest
walkJSPath (Just (JSON.Array ar)) (JSPIdx idx:rest) = walkJSPath (ar V.!? idx) rest
walkJSPath (Just (JSON.Array ar)) [JSPFilter (EqualsCond txt)] = findFirstMatch (==) txt ar
walkJSPath (Just (JSON.Array ar)) [JSPFilter (NotEqualsCond txt)] = findFirstMatch (/=) txt ar
walkJSPath (Just (JSON.Array ar)) [JSPFilter (StartsWithCond txt)] = findFirstMatch T.isPrefixOf txt ar
walkJSPath (Just (JSON.Array ar)) [JSPFilter (EndsWithCond txt)] = findFirstMatch T.isSuffixOf txt ar
walkJSPath (Just (JSON.Array ar)) [JSPFilter (ContainsCond txt)] = findFirstMatch T.isInfixOf txt ar
walkJSPath (Just (JSON.Array ar)) [JSPFilter filterCond] = case filterCond of
EqualsCond txt -> findFirstMatch (==) txt ar
NotEqualsCond txt -> findFirstMatch (/=) txt ar
StartsWithCond txt -> findFirstMatch T.isPrefixOf txt ar
EndsWithCond txt -> findFirstMatch T.isSuffixOf txt ar
ContainsCond txt -> findFirstMatch T.isInfixOf txt ar
walkJSPath _ _ = Nothing

findFirstMatch matchWith pattern = foldr checkMatch Nothing
Expand All @@ -151,55 +147,21 @@ parseClaims AppConfig{..} jclaims@(JSON.Object mclaims) = do
-- impossible case - just added to please -Wincomplete-patterns
parseClaims _ _ = return AuthResult { authClaims = KM.empty, authRole = mempty }

-- | Validate authorization header.
-- Parse and store JWT claims for future use in the request.
middleware :: AppState -> Wai.Middleware
middleware appState app req respond = do
-- | Perform authentication and authorization
-- Parse JWT and return AuthResult
getAuthResult :: AppState -> ApiRequest -> IO (Either Error AuthResult)
getAuthResult appState ApiRequest{..} = do
conf <- getConfig appState
time <- getTime appState

let token = Wai.extractBearerAuth =<< lookup HTTP.hAuthorization (Wai.requestHeaders req)
let ciHdrs = map (first CI.mk) iHeaders
token = Wai.extractBearerAuth =<< lookup HTTP.hAuthorization ciHdrs
parseJwt = runExceptT $ parseToken conf token time >>= parseClaims conf
jwtCacheState = getJwtCacheState appState

-- If ServerTimingEnabled -> calculate JWT validation time
-- If JwtCacheMaxLifetime -> cache JWT validation result
req' <- case (configServerTimingEnabled conf, configJwtCacheMaxLifetime conf) of
(True, 0) -> do
(dur, authResult) <- timeItT parseJwt
return $ req { Wai.vault = Wai.vault req & Vault.insert authResultKey authResult & Vault.insert jwtDurKey dur }

(True, maxLifetime) -> do
(dur, authResult) <- timeItT $ case token of
Just tkn -> lookupJwtCache jwtCacheState tkn maxLifetime parseJwt time
Nothing -> parseJwt
return $ req { Wai.vault = Wai.vault req & Vault.insert authResultKey authResult & Vault.insert jwtDurKey dur }

(False, 0) -> do
authResult <- parseJwt
return $ req { Wai.vault = Wai.vault req & Vault.insert authResultKey authResult }

(False, maxLifetime) -> do
authResult <- case token of
Just tkn -> lookupJwtCache jwtCacheState tkn maxLifetime parseJwt time
Nothing -> parseJwt
return $ req { Wai.vault = Wai.vault req & Vault.insert authResultKey authResult }

app req' respond

authResultKey :: Vault.Key (Either Error AuthResult)
authResultKey = unsafePerformIO Vault.newKey
{-# NOINLINE authResultKey #-}

getResult :: Wai.Request -> Maybe (Either Error AuthResult)
getResult = Vault.lookup authResultKey . Wai.vault

jwtDurKey :: Vault.Key Double
jwtDurKey = unsafePerformIO Vault.newKey
{-# NOINLINE jwtDurKey #-}

getJwtDur :: Wai.Request -> Maybe Double
getJwtDur = Vault.lookup jwtDurKey . Wai.vault

getRole :: Wai.Request -> Maybe BS.ByteString
getRole req = authRole <$> (rightToMaybe =<< getResult req)
case configJwtCacheMaxLifetime conf of
0 -> parseJwt -- If 0 then cache is diabled; no lookup
maxLifetime -> case token of
-- Lookup only if token found in header
Just tkn -> lookupJwtCache jwtCacheState tkn maxLifetime parseJwt time
Nothing -> parseJwt
14 changes: 6 additions & 8 deletions src/PostgREST/Logger.hs
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,9 @@ module PostgREST.Logger
, LoggerState
) where

import Control.AutoUpdate (defaultUpdateSettings,
mkAutoUpdate, updateAction)
import Control.Debounce
import qualified Data.ByteString.Char8 as BS
import Control.AutoUpdate (defaultUpdateSettings, mkAutoUpdate,
updateAction)
import Control.Debounce

import Data.Time (ZonedTime, defaultTimeLocale, formatTime,
getZonedTime)
Expand Down Expand Up @@ -56,15 +55,14 @@ logWithDebounce loggerState action = do
newDebouncer

-- TODO stop using this middleware to reuse the same "observer" pattern for all our logs
middleware :: LogLevel -> (Wai.Request -> Maybe BS.ByteString) -> Wai.Middleware
middleware logLevel getAuthRole =
middleware :: LogLevel -> Wai.Middleware
middleware logLevel =
unsafePerformIO $
Wai.mkRequestLogger Wai.defaultRequestLoggerSettings
{ Wai.outputFormat =
Wai.ApacheWithSettings $
Wai.defaultApacheSettings &
Wai.setApacheRequestFilter (\_ res -> shouldLogResponse logLevel $ Wai.responseStatus res) &
Wai.setApacheUserGetter getAuthRole
Wai.setApacheRequestFilter (\_ res -> shouldLogResponse logLevel $ Wai.responseStatus res)
Comment on lines -66 to +65
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now with this approach we can't get the role here, I think we should do this change as a separate break/fix?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should not do the change at all. Why remove the role from the log? It was added there on purpose, I even worked upstream (IIRC wai-logger) to support it.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess I couldn't figure out yet how to get the role after removing the auth middleware. So, I removed it temporarily. I'll try to restore this change.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let me put it this way: I remember introducing the middleware precisely to be able to add the role to the log output. So maybe this refactor is a dead-end. But if you come with a nice way to do it and will overall make the code better, sure.

, Wai.autoFlush = True
, Wai.destination = Wai.Handle stdout
}
Expand Down
16 changes: 9 additions & 7 deletions test/io/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -959,6 +959,8 @@ def test_log_level(level, defaultenv):
assert response.status_code == 200

output = sorted(postgrest.read_stdout(nlines=7))
for line in output:
print(line)

if level == "crit":
assert len(output) == 0
Expand All @@ -974,35 +976,35 @@ def test_log_level(level, defaultenv):
output[0],
)
assert re.match(
r'- - postgrest_test_anonymous \[.+\] "GET /unknown HTTP/1.1" 404 \d+ "" "python-requests/.+"',
r'- - - \[.+\] "GET /unknown HTTP/1.1" 404 \d+ "" "python-requests/.+"',
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not a refactor, this is a regression. The user's role was supposed to be in the log here and that's not the case anymore.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, sorry, forgot to mark this PR as draft. This changed isn't supposed to be done with this refactor. I changed this here to discuss about this.

output[1],
)
assert len(output) == 2
elif level == "info":
assert re.match(
r'- - - \[.+\] "GET / HTTP/1.1" 500 \d+ "" "python-requests/.+"',
r'- - - \[.+\] "GET / HTTP/1.1" 200 \d+ "" "python-requests/.+"',
output[0],
)
assert re.match(
r'- - postgrest_test_anonymous \[.+\] "GET / HTTP/1.1" 200 \d+ "" "python-requests/.+"',
r'- - - \[.+\] "GET / HTTP/1.1" 500 \d+ "" "python-requests/.+"',
output[1],
)
assert re.match(
r'- - postgrest_test_anonymous \[.+\] "GET /unknown HTTP/1.1" 404 \d+ "" "python-requests/.+"',
r'- - - \[.+\] "GET /unknown HTTP/1.1" 404 \d+ "" "python-requests/.+"',
output[2],
)
assert len(output) == 3
elif level == "debug":
assert re.match(
r'- - - \[.+\] "GET / HTTP/1.1" 500 \d+ "" "python-requests/.+"',
r'- - - \[.+\] "GET / HTTP/1.1" 200 \d+ "" "python-requests/.+"',
output[0],
)
assert re.match(
r'- - postgrest_test_anonymous \[.+\] "GET / HTTP/1.1" 200 \d+ "" "python-requests/.+"',
r'- - - \[.+\] "GET / HTTP/1.1" 500 \d+ "" "python-requests/.+"',
output[1],
)
assert re.match(
r'- - postgrest_test_anonymous \[.+\] "GET /unknown HTTP/1.1" 404 \d+ "" "python-requests/.+"',
r'- - - \[.+\] "GET /unknown HTTP/1.1" 404 \d+ "" "python-requests/.+"',
output[2],
)

Expand Down