From a8e97ea154fdca237aaa9fa56b654a6000be1d8f Mon Sep 17 00:00:00 2001 From: Elad Kaplan Date: Tue, 7 Jan 2025 14:01:31 +0200 Subject: [PATCH] fix cors issues --- src/controller/middleware/cors.rs | 83 +++++++++++++++---------------- tests/controller/middlewares.rs | 6 ++- 2 files changed, 45 insertions(+), 44 deletions(-) diff --git a/src/controller/middleware/cors.rs b/src/controller/middleware/cors.rs index 990ffb457..c0d8c53d8 100644 --- a/src/controller/middleware/cors.rs +++ b/src/controller/middleware/cors.rs @@ -10,7 +10,7 @@ use std::time::Duration; use axum::Router as AXRouter; use serde::{Deserialize, Serialize}; use serde_json::json; -use tower_http::cors; +use tower_http::cors::{self, Any}; use crate::{app::AppContext, controller::middleware::MiddlewareLayer, Result}; @@ -38,14 +38,8 @@ pub struct Cors { pub vary: Vec, } -impl Default for Cors { - fn default() -> Self { - serde_json::from_value(json!({})).unwrap() - } -} - fn default_allow_origins() -> Vec { - vec!["any".to_string()] + vec!["*".to_string()] } fn default_allow_headers() -> Vec { @@ -64,19 +58,13 @@ fn default_vary_headers() -> Vec { ] } -impl Cors { - #[must_use] - pub fn empty() -> Self { - Self { - enable: true, - allow_headers: vec![], - allow_methods: vec![], - allow_origins: vec![], - allow_credentials: false, - max_age: None, - vary: vec![], - } +impl Default for Cors { + fn default() -> Self { + serde_json::from_value(json!({})).unwrap() } +} + +impl Cors { /// Creates cors layer /// /// # Errors @@ -93,35 +81,46 @@ impl Cors { /// In all of these cases, the error returned will be the result of the /// `parse` method of the corresponding type. pub fn cors(&self) -> Result { - let mut cors: cors::CorsLayer = cors::CorsLayer::permissive(); - - let mut list = vec![]; + let mut cors: cors::CorsLayer = cors::CorsLayer::new(); // testing CORS, assuming https://example.com in the allow list: // $ curl -v --request OPTIONS 'localhost:5150/api/_ping' -H 'Origin: https://example.com' -H 'Acces // look for '< access-control-allow-origin: https://example.com' in response. // if it doesn't appear (test with a bogus domain), it is not allowed. - for origin in &self.allow_origins { - list.push(origin.parse()?); - } - if !list.is_empty() { - cors = cors.allow_origin(list); + if self.allow_origins == default_allow_origins() { + cors = cors.allow_origin(Any); + } else { + let mut list = vec![]; + for origin in &self.allow_origins { + list.push(origin.parse()?); + } + if !list.is_empty() { + cors = cors.allow_origin(list); + } } - let mut list = vec![]; - for header in &self.allow_headers { - list.push(header.parse()?); - } - if !list.is_empty() { - cors = cors.allow_headers(list); + if self.allow_headers == default_allow_headers() { + cors = cors.allow_headers(Any); + } else { + let mut list = vec![]; + for header in &self.allow_headers { + list.push(header.parse()?); + } + if !list.is_empty() { + cors = cors.allow_headers(list); + } } - let mut list = vec![]; - for method in &self.allow_methods { - list.push(method.parse()?); - } - if !list.is_empty() { - cors = cors.allow_methods(list); + if self.allow_methods == default_allow_methods() { + cors = cors.allow_methods(Any); + } else { + let mut list = vec![]; + for method in &self.allow_methods { + list.push(method.parse()?); + } + if !list.is_empty() { + cors = cors.allow_methods(list); + } } let mut list = vec![]; @@ -192,7 +191,7 @@ mod tests { #[case] allow_methods: Option>, #[case] max_age: Option, ) { - let mut middleware = Cors::empty(); + let mut middleware = Cors::default(); if let Some(allow_headers) = allow_headers { middleware.allow_headers = allow_headers; } @@ -238,7 +237,7 @@ mod tests { #[tokio::test] async fn cors_options() { - let mut middleware = Cors::empty(); + let mut middleware = Cors::default(); middleware.allow_origins = vec![ "http://localhost:8080".to_string(), "http://example.com".to_string(), diff --git a/tests/controller/middlewares.rs b/tests/controller/middlewares.rs index a15fda375..fb9b107da 100644 --- a/tests/controller/middlewares.rs +++ b/tests/controller/middlewares.rs @@ -170,8 +170,10 @@ async fn cors( let mut ctx: AppContext = tests_cfg::app::get_app_context().await; - let mut middleware = Cors::empty(); - middleware.enable = enable; + let mut middleware = Cors { + enable, + ..Default::default() + }; if let Some(allow_headers) = allow_headers { middleware.allow_headers = allow_headers;