diff --git a/components/net/http_loader.rs b/components/net/http_loader.rs index de066b8248e1..fe0ff7a1a648 100644 --- a/components/net/http_loader.rs +++ b/components/net/http_loader.rs @@ -23,6 +23,7 @@ use hyper::header::{Authorization, Basic, CacheControl, CacheDirective, ContentE use hyper::header::{ContentLength, Encoding, Header, Headers, Host, IfMatch, IfRange}; use hyper::header::{IfUnmodifiedSince, IfModifiedSince, IfNoneMatch, Location, Pragma, Quality}; use hyper::header::{QualityItem, Referer, SetCookie, UserAgent, qitem}; +use hyper::header::Origin as HyperOrigin; use hyper::method::Method; use hyper::net::Fresh; use hyper::status::StatusCode; @@ -789,6 +790,15 @@ fn http_redirect_fetch(request: Rc, main_fetch(request, cache, cors_flag, true, target, done_chan, context) } +fn try_url_origin_to_hyper_origin(url_origin: UrlOrigin) -> Option { + match url_origin { + // TODO Set "Origin: null" when hyper supports it + UrlOrigin::Opaque(_) => None, + UrlOrigin::Tuple(scheme, host, port) => + Some(HyperOrigin::new(scheme, host.to_string(), Some(port))) + } +} + /// [HTTP network or cache fetch](https://fetch.spec.whatwg.org#http-network-or-cache-fetch) fn http_network_or_cache_fetch(request: Rc, authentication_fetch_flag: bool, @@ -846,11 +856,20 @@ fn http_network_or_cache_fetch(request: Rc, unreachable!() }; - // Step 9 - if cors_flag || - (*http_request.method.borrow() != Method::Get && *http_request.method.borrow() != Method::Head) { - // TODO update this when https://github.com/hyperium/hyper/pull/691 is finished - // http_request.headers.borrow_mut().set_raw("origin", origin); + // Step 7 + if http_request.omit_origin_header.get() == false { + let method = http_request.method.borrow(); + if cors_flag || (*method != Method::Get && *method != Method::Head) { + match *http_request.origin.borrow() { + Origin::Origin(ref url_origin) => + match try_url_origin_to_hyper_origin(url_origin.clone()) { + Some(hyper_origin) => http_request.headers.borrow_mut().set(hyper_origin), + None => (), + }, + // TODO Set origin to client origin when request has client object + Origin::Client => (), + } + } } // Step 10 diff --git a/tests/unit/net/http_loader.rs b/tests/unit/net/http_loader.rs index 4dc165f75053..437bd0e77822 100644 --- a/tests/unit/net/http_loader.rs +++ b/tests/unit/net/http_loader.rs @@ -13,8 +13,8 @@ use flate2::Compression; use flate2::write::{DeflateEncoder, GzEncoder}; use hyper::LanguageTag; use hyper::header::{Accept, AcceptEncoding, ContentEncoding, ContentLength, Cookie as CookieHeader}; -use hyper::header::{AcceptLanguage, Authorization, Basic, Date}; -use hyper::header::{Encoding, Headers, Host, Location, Quality, QualityItem, SetCookie, qitem}; +use hyper::header::{AcceptLanguage, AccessControlAllowOrigin, Authorization, Basic, Date}; +use hyper::header::{Encoding, Headers, Host, Location, Origin, Quality, QualityItem, SetCookie, qitem}; use hyper::header::{StrictTransportSecurity, UserAgent}; use hyper::method::Method; use hyper::mime::{Mime, SubLevel, TopLevel}; @@ -28,12 +28,13 @@ use net::cookie_storage::CookieStorage; use net::resource_thread::AuthCacheEntry; use net_traits::{CookieSource, NetworkError}; use net_traits::hosts::replace_host_table; -use net_traits::request::{Request, RequestInit, CredentialsMode, Destination}; +use net_traits::request::{Request, RequestInit, RequestMode, CredentialsMode, Destination}; use net_traits::response::ResponseBody; use new_fetch_context; use servo_url::ServoUrl; use std::collections::HashMap; use std::io::{Read, Write}; +use std::str::FromStr; use std::sync::{Arc, Mutex, RwLock, mpsc}; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::mpsc::Receiver; @@ -146,8 +147,13 @@ fn test_check_default_headers_loaded_in_every_request() { assert!(response.status.unwrap().is_success()); // Testing for method.POST - headers.set(ContentLength(0 as u64)); - *expected_headers.lock().unwrap() = Some(headers.clone()); + let mut post_headers = headers.clone(); + post_headers.set(ContentLength(0 as u64)); + let url_str = url.as_str(); + // request gets header "Origin: http://example.com" but expected_headers has + // "Origin: http://example.com/" which do not match for equality so strip trailing '/' + post_headers.set(Origin::from_str(&url_str[..url_str.len()-1]).unwrap()); + *expected_headers.lock().unwrap() = Some(post_headers); let request = Request::from_init(RequestInit { url: url.clone(), method: Method::Post, @@ -1193,3 +1199,61 @@ fn test_cookies_blocked() { assert!(response.status.unwrap().is_success()); } + +#[test] +fn test_origin_set() { + let origin_header = Arc::new(Mutex::new(None)); + let origin_header_clone = origin_header.clone(); + let handler = move |request: HyperRequest, mut resp: HyperResponse| { + let origin_header_clone = origin_header.clone(); + resp.headers_mut().set(AccessControlAllowOrigin::Any); + match request.headers.get::() { + None => assert_eq!(origin_header_clone.lock().unwrap().take(), None), + Some(h) => assert_eq!(*h, origin_header_clone.lock().unwrap().take().unwrap()), + } + }; + let (mut server, url) = make_server(handler); + + let mut origin = Origin::new(url.scheme(), url.host_str().unwrap(), url.port()); + *origin_header_clone.lock().unwrap() = Some(origin.clone()); + let request = Request::from_init(RequestInit { + url: url.clone(), + method: Method::Post, + body: None, + origin: url.clone(), + .. RequestInit::default() + }); + let response = fetch(request, None); + assert!(response.status.unwrap().is_success()); + + let origin_url = ServoUrl::parse("http://example.com").unwrap(); + origin = Origin::new(origin_url.scheme(), origin_url.host_str().unwrap(), origin_url.port()); + // Test Origin header is set on Get request with CORS mode + let request = Request::from_init(RequestInit { + url: url.clone(), + method: Method::Get, + mode: RequestMode::CorsMode, + body: None, + origin: origin_url.clone(), + .. RequestInit::default() + }); + + *origin_header_clone.lock().unwrap() = Some(origin.clone()); + let response = fetch(request, None); + assert!(response.status.unwrap().is_success()); + + // Test Origin header is not set on method Head + let request = Request::from_init(RequestInit { + url: url.clone(), + method: Method::Head, + body: None, + origin: url.clone(), + .. RequestInit::default() + }); + + *origin_header_clone.lock().unwrap() = None; + let response = fetch(request, None); + assert!(response.status.unwrap().is_success()); + + let _ = server.close(); +}