dogy_backend_api/middleware/auth/layer.rs

//! This module contains the axum middleware layers for authentication and for resolving the
//! authenticated user's internal (database) ID.
2use axum::extract::State;
3use axum::{extract::Request, http::header, middleware::Next, response::Response};
4use tracing::debug;
5use uuid::Uuid;
6
7use crate::AppState;
8
9use super::core::authenticate_user;
10use super::Error as AuthError;
11use crate::{Error, Result};
12
13/// Represents the metadata about the current user.
14///
15/// This is passed throughout the middlewares and also handlers.
16#[derive(Clone, Debug, PartialEq)]
17pub struct CurrentUser {
18    /// Clerk User ID of the current user.
19    pub user_id: String,
20
21    /// Clerk Role (normally None for regular users) of the current user.
22    #[allow(dead_code)]
23    pub role: Option<String>,
24
25    /// Database User ID of the current user. Initially, this should be None as we will retrieve
26    /// the internal ID from the database after authentication occurs.
27    pub internal_id: Option<Uuid>,
28}
29
30/// Axum middleware for authentication.
31///
32/// Accepts `Authorization` header from a request and validates it.
33/// Afterwards, it'll decode the JWT token and inject the user's details
34/// into the request.
35pub async fn auth_middleware(mut req: Request, next: Next) -> Result<Response> {
36    // Retrieve authorization header
37    let auth_header = req
38        .headers()
39        .get(header::AUTHORIZATION)
40        .and_then(|header| header.to_str().ok())
41        .ok_or(Error::Auth(AuthError::MissingAuthHeader))?
42        .strip_prefix("Bearer ")
43        .ok_or(Error::Auth(AuthError::NoBearerPrefix))?;
44
45    let current_user = authenticate_user(auth_header)?;
46
47    // Inject the current user's details for both request and response.
48    // Injecting it to response is necessary in order for it to be extracted later on in
49    // response_mapper or logging middleware.
50    req.extensions_mut().insert(current_user.clone());
51    let mut res = next.run(req).await;
52    res.extensions_mut().insert(current_user);
53    Ok(res)
54}
55
56/// Axum middleware for retrieving the internal ID for the current user.
57///
58/// If it passes through authentication but fails to retrieve the internal ID of the user,
59/// it'll simply return a [`AuthError::UserNotFound`] error.
60pub async fn get_internal_id(
61    State(state): State<AppState>,
62    mut req: Request,
63    next: Next,
64) -> Result<Response> {
65    let current_user = req.extensions_mut().get::<CurrentUser>().unwrap().clone();
66
67    let query = "SELECT id FROM users where external_id = $1";
68    let user_id = &current_user.user_id;
69    let internal_id: Uuid = sqlx::query_scalar(query)
70        .bind(user_id)
71        .fetch_one(&*state.db)
72        .await
73        .unwrap_or(None)
74        .ok_or(Error::Auth(AuthError::UserNotFound {
75            user_id: user_id.to_string(),
76        }))?;
77
78    let updated_user = CurrentUser {
79        internal_id: Some(internal_id),
80        ..current_user
81    };
82    debug!("Current user: {:?}", &updated_user);
83    req.extensions_mut().insert(updated_user);
84    Ok(next.run(req).await)
85}
86
#[cfg(test)]
mod test {
    use super::*;
    use crate::middleware::{log::layer::log_middleware, test::*};
    use axum::middleware;
    use axum_test::TestServer;
    use reqwest::StatusCode;
    use serde_json::json;
    use std::env;
    use testcontainers::ContainerAsync;
    use testcontainers_modules::postgres::Postgres;

    /// Clerk user ID that the `JWT_TOKEN` fixture authenticates as.
    const TEST_USER_ID: &str = "user_2ruHSXCzfIRreR2tpttVQBl512a";

    /// Loads `.env.test` (ignoring a missing file) and returns the `JWT_TOKEN` fixture.
    fn load_jwt_token() -> String {
        let _ = dotenv::from_filename(".env.test");
        env::var("JWT_TOKEN").expect("JWT_TOKEN must be set (e.g. via .env.test) for auth tests")
    }

    /// Test server with [`auth_middleware`] (plus logging) in front of the test route.
    fn setup_test_server_with_auth() -> TestServer {
        let test_route = setup_test_router()
            .route_layer(middleware::from_fn(auth_middleware))
            .layer(middleware::map_response(log_middleware));

        TestServer::new(test_route).unwrap()
    }

    /// Like [`setup_test_server_with_auth`], but also asserts the injected [`CurrentUser`]
    /// before the test route runs.
    fn setup_test_server_with_assert_user() -> TestServer {
        let test_route = setup_test_router()
            .layer(middleware::from_fn(assert_current_user_in_extensions_mw))
            .route_layer(middleware::from_fn(auth_middleware))
            .layer(middleware::map_response(log_middleware));

        TestServer::new(test_route).unwrap()
    }

    /// Test server backed by a Postgres test container, running both auth middlewares.
    ///
    /// When `create_user` is true, a user row for [`TEST_USER_ID`] is inserted first. The
    /// container is returned alongside the server so callers can keep it alive for the whole
    /// test (dropping it presumably stops the database).
    async fn setup_test_server_with_state(
        create_user: bool,
    ) -> (TestServer, ContainerAsync<Postgres>) {
        let (state, container) = setup_test_db().await;
        if create_user {
            sqlx::query(
                r#"INSERT INTO users (name, external_id, timezone, gender, has_onboarded)
                VALUES ('Test User', 'user_2ruHSXCzfIRreR2tpttVQBl512a', 'Europe/Stockholm', 'male', true);
                "#,
            )
            .execute(&*state.db)
            .await
            .unwrap();
        }
        let test_route = setup_test_router()
            .layer(middleware::from_fn(assert_current_user_full_in_ext_mw))
            .layer(middleware::from_fn_with_state(state, get_internal_id))
            .layer(middleware::from_fn(auth_middleware))
            .layer(middleware::map_response(log_middleware));

        (TestServer::new(test_route).unwrap(), container)
    }

    /// Asserts that [`auth_middleware`] injected the expected user, with no internal ID yet.
    async fn assert_current_user_in_extensions_mw(req: Request, next: Next) -> Result<Response> {
        let current_user = req
            .extensions()
            .get::<CurrentUser>()
            .expect("CurrentUser should be injected by auth_middleware");
        assert_eq!(current_user.user_id, TEST_USER_ID);
        assert_eq!(current_user.role, None);
        assert_eq!(current_user.internal_id, None);
        Ok(next.run(req).await)
    }

    /// Asserts that [`get_internal_id`] populated the internal ID of the expected user.
    async fn assert_current_user_full_in_ext_mw(req: Request, next: Next) -> Result<Response> {
        let current_user = req
            .extensions()
            .get::<CurrentUser>()
            .expect("CurrentUser should be injected by auth_middleware");
        assert_eq!(current_user.user_id, TEST_USER_ID);
        assert_eq!(current_user.role, None);
        assert!(current_user.internal_id.is_some());
        Ok(next.run(req).await)
    }

    #[tokio::test]
    async fn test_auth_middleware_ok() {
        let app = setup_test_server_with_auth();
        let jwt_token = load_jwt_token();

        let response = app
            .get("/")
            .add_header("Authorization", format!("Bearer {}", jwt_token))
            .await;

        response.assert_status(StatusCode::OK);
        response.assert_text("Middleware test succeeded");
    }

    #[tokio::test]
    async fn test_auth_middleware_missing_auth_header_err() {
        let app = setup_test_server_with_auth();

        let response = app.get("/").await;

        response.assert_status(StatusCode::UNAUTHORIZED);
        response.assert_header("content-type", "application/json");
        response.assert_json(&json!({
            "status": "error",
            "code": "MISSING_AUTH_HEADER"
        }));
    }

    #[tokio::test]
    async fn test_auth_middleware_no_bearer_prefix_err() {
        let app = setup_test_server_with_auth();

        let response = app.get("/").add_header("Authorization", "some_token").await;

        response.assert_status(StatusCode::UNAUTHORIZED);
        response.assert_header("content-type", "application/json");
        response.assert_json(&json!({
            "status": "error",
            "code": "NO_BEARER_PREFIX"
        }));
    }

    #[tokio::test]
    async fn test_auth_middleware_invalid_token_err() {
        let app = setup_test_server_with_auth();

        let response = app
            .get("/")
            .add_header("Authorization", "Bearer invalid_token")
            .await;

        response.assert_status(StatusCode::UNAUTHORIZED);
        response.assert_header("content-type", "application/json");
        response.assert_json(&json!({
            "status": "error",
            "code": "INVALID_CREDENTIALS"
        }));
    }

    #[tokio::test]
    async fn test_auth_middleware_current_user_ext_ok() {
        let app = setup_test_server_with_assert_user();
        let jwt_token = load_jwt_token();

        let response = app
            .get("/")
            .add_header("Authorization", format!("Bearer {}", jwt_token))
            .await;

        response.assert_status(StatusCode::OK);
        response.assert_text("Middleware test succeeded");
    }

    #[tokio::test]
    async fn test_get_internal_id_mw_ok() {
        // `_container` (not `_`) keeps the database alive until the end of the test.
        let (app, _container) = setup_test_server_with_state(true).await;
        let jwt_token = load_jwt_token();

        let response = app
            .get("/")
            .add_header("Authorization", format!("Bearer {}", jwt_token))
            .await;

        response.assert_status(StatusCode::OK);
        response.assert_text("Middleware test succeeded");
    }

    #[tokio::test]
    async fn test_get_internal_id_mw_user_not_found_err() {
        // `_container` (not `_`) keeps the database alive until the end of the test.
        let (app, _container) = setup_test_server_with_state(false).await;
        let jwt_token = load_jwt_token();

        let response = app
            .get("/")
            .add_header("Authorization", format!("Bearer {}", jwt_token))
            .await;

        response.assert_status(StatusCode::NOT_FOUND);
        response.assert_header("content-type", "application/json");
        response.assert_json(&json!({
            "status": "error",
            "code": "USER_NOT_FOUND",
            "details": {
                "user_id": TEST_USER_ID
            }
        }));
    }
}