Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 88 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 4 additions & 1 deletion simple-hyper-client/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ categories = ["web-programming::http-client"]
edition = "2018"

[dependencies]
derive_more = { version = "2", features = ["is_variant", "try_unwrap", "unwrap"] }
futures-executor = "0.3"
futures-util = "0.3"
headers = "0.4"
Expand All @@ -26,5 +27,7 @@ tokio = { version = "1", features = ["rt", "macros", "net", "sync", "time"] }
tower-service = "0.3"

[dev-dependencies]
http-body-util = { version = "0.1", features = ["channel"] }
futures-util = "0.3"
http-body-util = { version = "0.1", features = ["channel"] }
httparse = "1"
test-case = "3"
100 changes: 72 additions & 28 deletions simple-hyper-client/src/async_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@
* License, v. 2.0. If a copy of the MPL was not distributed with this
* file, You can obtain one at http://mozilla.org/MPL/2.0/. */

use crate::body::RequestBody;
use crate::connector::{ConnectorAdapter, NetworkConnector};
use crate::error::Error;
use crate::shared_body::SharedBody;
use crate::{HyperClient, HyperClientBuilder, Response};

use headers::{ContentLength, Header, HeaderMap, HeaderMapExt};
use headers::{Header, HeaderMap, HeaderMapExt};
use hyper::body::Body;
use hyper::{Method, Request, Uri};

use std::convert::{TryFrom, TryInto};
Expand All @@ -30,7 +31,7 @@ use std::time::Duration;
/// [hyper's `Client` type]: https://docs.rs/hyper-util/latest/hyper_util/client/legacy/struct.Client.html
#[derive(Clone)]
pub struct Client {
inner: Arc<HyperClient<ConnectorAdapter, SharedBody>>,
inner: Arc<HyperClient<ConnectorAdapter, RequestBody>>,
}

macro_rules! define_method_fn {
Expand Down Expand Up @@ -66,7 +67,7 @@ impl Client {

/// This method can be used instead of [Client::request]
/// if the caller already has a [Request].
pub async fn send(&self, request: Request<SharedBody>) -> Result<Response, Error> {
pub async fn send(&self, request: Request<RequestBody>) -> Result<Response, Error> {
Ok(self.inner.request(request).await?)
}

Expand Down Expand Up @@ -158,7 +159,7 @@ pub(crate) struct RequestDetails {
pub(crate) method: Method,
pub(crate) uri: Uri,
pub(crate) headers: HeaderMap,
pub(crate) body: Option<SharedBody>,
pub(crate) body: Option<RequestBody>,
}

impl fmt::Debug for RequestDetails {
Expand Down Expand Up @@ -187,32 +188,26 @@ impl RequestDetails {
Ok(client.inner.request(req).await?)
}

pub fn into_request(mut self) -> Result<Request<SharedBody>, Error> {
pub fn into_request(self) -> Result<Request<RequestBody>, Error> {
let can_have_body = match self.method {
// See RFC 7231 section 4.3
Method::GET | Method::HEAD | Method::DELETE => false,
_ => true,
};
let body = match can_have_body {
true => {
let body = self.body.unwrap_or_else(|| SharedBody::empty());
// NOTE: body cannot be chunked in this implementation, so we
// don't worry about chunked encoding here. But if this changes
// then we should not set `ContentLength` automatically if the
// request body is chunked, see RFC 7230 section 3.3.2.
self.headers.typed_insert(ContentLength(body.len() as u64));
body
}
false if self.body.is_some() => return Err(Error::BodyNotAllowed(self.method)),
false => SharedBody::empty(),
let body = if can_have_body {
self.body.unwrap_or_else(|| RequestBody::empty())
} else if self.body.is_some_and(|body| body.size_hint().lower() > 0) {
return Err(Error::BodyNotAllowed(self.method));
} else {
RequestBody::empty()
};
let mut req = Request::builder().method(self.method).uri(self.uri);
match req.headers_mut() {
Some(headers) => {
*headers = self.headers;
}
// There is an error in req, but the only way to extract the error is through `req.body()`
None => match req.body(SharedBody::empty()) {
None => match req.body(RequestBody::empty()) {
Err(e) => return Err(e.into()),
Ok(_) => {
panic!("request builder must have errors if `fn headers_mut()` returns None")
Expand Down Expand Up @@ -241,7 +236,7 @@ pub struct RequestBuilder<'a> {

impl<'a> RequestBuilder<'a> {
/// Set the request body.
pub fn body<B: Into<SharedBody>>(mut self, body: B) -> Self {
pub fn body<B: Into<RequestBody>>(mut self, body: B) -> Self {
self.details.body = Some(body.into());
self
}
Expand All @@ -264,7 +259,7 @@ impl<'a> RequestBuilder<'a> {
///
/// Prefer [RequestBuilder::send] unless you have a specific
/// need to get the resultant [Request].
pub fn build(self) -> Result<Request<SharedBody>, Error> {
pub fn build(self) -> Result<Request<RequestBody>, Error> {
self.details.into_request()
}

Expand Down Expand Up @@ -298,31 +293,35 @@ mod tests {
use super::*;
use crate::connector::HttpConnector;
use crate::util::to_bytes;
use headers::ContentType;
use headers::{ContentLength, ContentType};
use hyper::StatusCode;
use std::net::SocketAddr;
use test_case::test_case;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpListener;
use tokio::sync::oneshot;

const RESPONSE_OK: &str = "HTTP/1.1 200 OK\r\nContent-Length: 13\r\n\r\nHello, world!\r\n";
const RESPONSE_404: &str =
"HTTP/1.1 404 Not Found\r\nContent-Length: 23\r\n\r\nResource was not found.\r\n";

async fn test_http_server(resp: &'static str) -> SocketAddr {
async fn test_http_server(resp: &'static str, body_tx: oneshot::Sender<Vec<u8>>) -> SocketAddr {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
tokio::spawn(async move {
let (mut stream, _) = listener.accept().await.unwrap();
let mut input = Vec::new();
stream.read(&mut input).await.unwrap();
stream.write_all(resp.as_bytes()).await.unwrap();
stream.read_to_end(&mut input).await.unwrap();
let _ = body_tx.send(input);
});
addr
}

#[tokio::test]
async fn http_client() {
let addr = test_http_server(RESPONSE_OK).await;
let (tx, rx) = oneshot::channel();
let addr = test_http_server(RESPONSE_OK, tx).await;
let url = format!("http://{}/", addr);

let connector = HttpConnector::new();
Expand All @@ -336,14 +335,59 @@ mod tests {
.await
.unwrap();

// Parse the request received by the server
let mut headers = [httparse::EMPTY_HEADER; 64];
let mut request = httparse::Request::new(&mut headers);
let req_buf = rx.await.unwrap();
let body_idx = request.parse(&req_buf).unwrap().unwrap();
assert_eq!(request.method, Some("POST"));
assert_eq!(request.path, Some("/"));
assert_eq!(request.version, Some(1));
let content_length = request
.headers
.iter()
.find(|header| header.name == ContentLength::name())
.unwrap();
assert_eq!(content_length.value, "15".as_bytes());
assert_eq!(
str::from_utf8(&req_buf[body_idx..]).unwrap(),
"{\"key\":\"value\"}"
);

assert_eq!(response.status(), StatusCode::OK);
let body = to_bytes(response).await.unwrap();
assert_eq!(body, "Hello, world!".as_bytes());
let response_body = to_bytes(response).await.unwrap();
assert_eq!(response_body, "Hello, world!".as_bytes());
}

#[test_case(Some(r#"{"key":"value"}"#.into()), false; "non-empty body not allowed")]
#[test_case(Some("".into()), true; "empty body allowed")]
#[test_case(None, true; "without body allowed")]
#[tokio::test]
async fn get_request(body: Option<RequestBody>, expect_ok: bool) {
let (tx, _rx) = oneshot::channel();
let addr = test_http_server(RESPONSE_OK, tx).await;
let url = format!("http://{}/", addr);

let connector = HttpConnector::new();
let client = Client::with_connector(connector);
let mut builder = client.get(url).unwrap();

if let Some(body) = body {
builder = builder.header(ContentType::json()).body(body);
}

let result = builder.send().await;
if expect_ok {
result.unwrap();
} else {
assert_eq!(result.unwrap_err().unwrap_body_not_allowed(), Method::GET);
}
}

#[tokio::test]
async fn drop_client_before_response() {
let addr = test_http_server(RESPONSE_404).await;
let (tx, _rx) = oneshot::channel();
let addr = test_http_server(RESPONSE_404, tx).await;
let url = format!("http://{}/", addr);

let connector = HttpConnector::new();
Expand Down
4 changes: 2 additions & 2 deletions simple-hyper-client/src/blocking/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
use super::body::Body;
use super::Response;
use crate::async_client::{ClientBuilder as AsyncClientBuilder, RequestDetails};
use crate::body::RequestBody;
use crate::connector::NetworkConnector;
use crate::error::Error;
use crate::shared_body::SharedBody;

use futures_executor::block_on;
use headers::{Header, HeaderMap, HeaderMapExt};
Expand Down Expand Up @@ -206,7 +206,7 @@ pub struct RequestBuilder<'a> {

impl<'a> RequestBuilder<'a> {
/// Set the request body.
pub fn body<B: Into<SharedBody>>(mut self, body: B) -> Self {
pub fn body<B: Into<RequestBody>>(mut self, body: B) -> Self {
self.details.body = Some(body.into());
self
}
Expand Down
Loading