diff --git a/mongoose.c b/mongoose.c index 283e98f5e4..7adeffe419 100644 --- a/mongoose.c +++ b/mongoose.c @@ -6787,6 +6787,8 @@ struct connstate { uint64_t timer; // TCP timer (see 'ttype' below) uint32_t acked; // Last ACK-ed number size_t unacked; // Not acked bytes + uint32_t maxseq; // Max send seq (ack + window) + uint16_t win; // destination current window size uint16_t dmss; // destination MSS (from TCP opts) uint8_t mac[sizeof(struct mg_l2addr)]; // Peer hw address uint8_t ttype; // Timer type: @@ -8032,6 +8034,7 @@ static struct mg_connection *accept_conn(struct mg_connection *lsn, s = (struct connstate *) (c + 1); s->dmss = mss; // from options in client SYN s->seq = mg_ntohl(pkt->tcp->ack), s->ack = mg_ntohl(pkt->tcp->seq); + s->win = mg_ntohs(pkt->tcp->win), s->maxseq = (uint32_t)(s->seq + s->win); #if MG_ENABLE_IPV6 if (lsn->loc.is_ip6) { c->rem.addr.ip6[0] = pkt->ip6->src[0], @@ -8124,10 +8127,13 @@ long mg_io_send(struct mg_connection *c, const void *buf, size_t len) { len = trim_len(c, len); if (c->is_udp) { if (!udp_send(c, buf, len)) return MG_IO_WAIT; - } else { // TCP, cap to peer's MSS + } else { // TCP, cap to peer's MSS and check window struct mg_tcpip_if *ifp = c->mgr->ifp; size_t sent; + uint32_t room = s->maxseq - s->seq; + if (room == 0) return MG_IO_WAIT; if (len > s->dmss) len = s->dmss; // RFC-6691: reduce if sending opts + if ((uint32_t) len > room) len = room; sent = tx_tcp(ifp, s->mac, &c->loc, &c->rem, TH_PUSH | TH_ACK, mg_htonl(s->seq), mg_htonl(s->ack), buf, len); if (sent == 0) { @@ -8162,6 +8168,12 @@ static void handle_tls_recv(struct mg_connection *c) { } } +static void handle_ack(struct connstate *s, uint32_t ackno, uint16_t win) { + if (ackno < (s->seq - s->win) || ackno > s->seq) return; + s->maxseq = (uint32_t)(ackno + win); + s->win = win; +} + static void read_conn(struct mg_connection *c, struct pkt *pkt) { struct connstate *s = (struct connstate *) (c + 1); struct mg_iobuf *io = c->is_tls ? &c->rtls : &c->recv; @@ -8201,6 +8213,8 @@ static void read_conn(struct mg_connection *c, struct pkt *pkt) { mg_htonl(s->ack), NULL, 0); return; // no data to process } else if (pkt->pay.len == 0) { // this is an ACK + if (pkt->tcp->flags & TH_ACK) + handle_ack(s, mg_ntohl(pkt->tcp->ack), mg_ntohs(pkt->tcp->win)); if (s->fin_rcvd && s->ttype == MIP_TTYPE_FIN) s->twclosure = true; return; // no data to process } else if (seq != s->ack) { @@ -8218,6 +8232,8 @@ static void read_conn(struct mg_connection *c, struct pkt *pkt) { mg_error(c, "oom"); return; // drop it } + if (pkt->tcp->flags & TH_ACK) + handle_ack(s, mg_ntohl(pkt->tcp->ack), mg_ntohs(pkt->tcp->win)); // Copy TCP payload into the IO buffer. If the connection is plain text, // we copy to c->recv. If the connection is TLS, this data is encrypted, // therefore we copy that encrypted data to the c->rtls iobuffer instead, @@ -8336,6 +8352,7 @@ static void rx_tcp(struct mg_tcpip_if *ifp, struct pkt *pkt) { if (!handle_opt(s, pkt->tcp, pkt->ip6 != NULL)) return; // process options (MSS) s->seq = mg_ntohl(pkt->tcp->ack), s->ack = mg_ntohl(pkt->tcp->seq) + 1; + s->win = mg_ntohs(pkt->tcp->win), s->maxseq = (uint32_t)(s->seq + s->win); tx_tcp_ctrlresp(ifp, pkt, TH_ACK, pkt->tcp->ack); c->is_connecting = 0; // Client connected settmout(c, MIP_TTYPE_KEEPALIVE); diff --git a/src/net_builtin.c b/src/net_builtin.c index bea2bc7272..1598d247c4 100644 --- a/src/net_builtin.c +++ b/src/net_builtin.c @@ -23,6 +23,8 @@ struct connstate { uint64_t timer; // TCP timer (see 'ttype' below) uint32_t acked; // Last ACK-ed number size_t unacked; // Not acked bytes + uint32_t maxseq; // Max send seq (ack + window) + uint16_t win; // destination current window size uint16_t dmss; // destination MSS (from TCP opts) uint8_t mac[sizeof(struct mg_l2addr)]; // Peer hw address uint8_t ttype; // Timer type: @@ -1268,6 +1270,7 @@ static struct mg_connection *accept_conn(struct mg_connection *lsn, s = (struct connstate *) (c + 1); s->dmss = mss; // from options in client SYN s->seq = mg_ntohl(pkt->tcp->ack), s->ack = mg_ntohl(pkt->tcp->seq); + s->win = mg_ntohs(pkt->tcp->win), s->maxseq = (uint32_t)(s->seq + s->win); #if MG_ENABLE_IPV6 if (lsn->loc.is_ip6) { c->rem.addr.ip6[0] = pkt->ip6->src[0], @@ -1360,10 +1363,13 @@ long mg_io_send(struct mg_connection *c, const void *buf, size_t len) { len = trim_len(c, len); if (c->is_udp) { if (!udp_send(c, buf, len)) return MG_IO_WAIT; - } else { // TCP, cap to peer's MSS + } else { // TCP, cap to peer's MSS and check window struct mg_tcpip_if *ifp = c->mgr->ifp; size_t sent; + uint32_t room = s->maxseq - s->seq; + if (room == 0) return MG_IO_WAIT; if (len > s->dmss) len = s->dmss; // RFC-6691: reduce if sending opts + if ((uint32_t) len > room) len = room; sent = tx_tcp(ifp, s->mac, &c->loc, &c->rem, TH_PUSH | TH_ACK, mg_htonl(s->seq), mg_htonl(s->ack), buf, len); if (sent == 0) { @@ -1398,6 +1404,12 @@ static void handle_tls_recv(struct mg_connection *c) { } } +static void handle_ack(struct connstate *s, uint32_t ackno, uint16_t win) { + if (ackno < (s->seq - s->win) || ackno > s->seq) return; + s->maxseq = (uint32_t)(ackno + win); + s->win = win; +} + static void read_conn(struct mg_connection *c, struct pkt *pkt) { struct connstate *s = (struct connstate *) (c + 1); struct mg_iobuf *io = c->is_tls ? &c->rtls : &c->recv; @@ -1437,6 +1449,8 @@ static void read_conn(struct mg_connection *c, struct pkt *pkt) { mg_htonl(s->ack), NULL, 0); return; // no data to process } else if (pkt->pay.len == 0) { // this is an ACK + if (pkt->tcp->flags & TH_ACK) + handle_ack(s, mg_ntohl(pkt->tcp->ack), mg_ntohs(pkt->tcp->win)); if (s->fin_rcvd && s->ttype == MIP_TTYPE_FIN) s->twclosure = true; return; // no data to process } else if (seq != s->ack) { @@ -1454,6 +1468,8 @@ static void read_conn(struct mg_connection *c, struct pkt *pkt) { mg_error(c, "oom"); return; // drop it } + if (pkt->tcp->flags & TH_ACK) + handle_ack(s, mg_ntohl(pkt->tcp->ack), mg_ntohs(pkt->tcp->win)); // Copy TCP payload into the IO buffer. If the connection is plain text, // we copy to c->recv. If the connection is TLS, this data is encrypted, // therefore we copy that encrypted data to the c->rtls iobuffer instead, @@ -1572,6 +1588,7 @@ static void rx_tcp(struct mg_tcpip_if *ifp, struct pkt *pkt) { if (!handle_opt(s, pkt->tcp, pkt->ip6 != NULL)) return; // process options (MSS) s->seq = mg_ntohl(pkt->tcp->ack), s->ack = mg_ntohl(pkt->tcp->seq) + 1; + s->win = mg_ntohs(pkt->tcp->win), s->maxseq = (uint32_t)(s->seq + s->win); tx_tcp_ctrlresp(ifp, pkt, TH_ACK, pkt->tcp->ack); c->is_connecting = 0; // Client connected settmout(c, MIP_TTYPE_KEEPALIVE); diff --git a/test/mip_test.c b/test/mip_test.c index ca63981bec..e7fc338e20 100644 --- a/test/mip_test.c +++ b/test/mip_test.c @@ -5,6 +5,8 @@ #include "driver_mock.c" +#define TCP_TEST_WIN 3000 // arbitrary peer window size to test txwindow + static int s_num_tests = 0; static bool s_error = false; static int s_sent_fragment = 0; @@ -116,6 +118,17 @@ static void tcpclosure_fn(struct mg_connection *c, int ev, void *ev_data) { (void) c, (void) ev_data; } +static void txwindow_fn(struct mg_connection *c, int ev, void *ev_data) { + if (ev == MG_EV_ACCEPT) { + char bigdata[2 * TCP_TEST_WIN + 256]; + *(int *) c->fn_data = 0; + mg_send(c, bigdata, sizeof(bigdata)); + } else if (ev == MG_EV_WRITE) { + ++(*(int *) c->fn_data); + } + (void) c, (void) ev_data; +} + static void client_fn(struct mg_connection *c, int ev, void *ev_data) { if (ev == MG_EV_ERROR || ev == MG_EV_CONNECT) (*(int *) c->fn_data) = ev; (void) c, (void) ev_data; @@ -139,7 +152,7 @@ static void frag_send_fn(struct mg_connection *c, int ev, void *ev_data) { if (ev == MG_EV_POLL) { if (!s_sent) { struct connstate *s = (struct connstate *) (c + 1); - s->dmss = 1500; // mock set some destination MSS way larger + s->dmss = 1500, s->maxseq = 1500; // mock set dest MSS and win way larger c->send.len = 1200; // setting TCP payload size s_sent = true; } @@ -214,6 +227,7 @@ static void create_tcp_seg(struct eth *e, struct ipp *ipp, uint32_t seq, t.ack = mg_htonl(ack); t.sport = mg_htons(sport); t.dport = mg_htons(dport); + t.win = mg_htons(TCP_TEST_WIN); t.off = (uint8_t) ((sizeof(t) / 4) << 4) + (uint8_t) ((opts_len / 4) << 4); memcpy(s_driver_data.buf, e, sizeof(*e)); #if MG_ENABLE_IPV6 @@ -898,6 +912,54 @@ static void test_tcp_retransmit(void) { mg_mgr_free(&mgr); } + +static void test_tcp_txwindow(void) { + struct mg_mgr mgr; + struct eth e; + struct ip ip; + struct ipp ipp; + struct tcp *t = (struct tcp *) (s_driver_data.buf + sizeof(e) + sizeof(ip)); + int count = 0, stallcount; + uint32_t seq; + //bool response_recv = true; + struct mg_tcpip_driver driver; + struct mg_tcpip_if mif; + + ipp.ip4 = &ip; + ipp.ip6 = NULL; + + init_tcp_tests(&mgr, &e, &ipp, &driver, &mif, txwindow_fn); + mgr.conns->fn_data = &count; + init_tcp_handshake(&e, &ipp, &mgr); // starts with seq_no=1000, ackno=2 + ASSERT((t->seq == mg_htonl(2))); + do { + while (!received_response(&s_driver_data)) mg_mgr_poll(&mgr, 0); + seq = (uint32_t)(mg_htonl(t->seq) + s_driver_data.len - (size_t)((char *)((uint32_t *)t + (t->off >> 4)) - s_driver_data.buf)); + } while (seq < (TCP_TEST_WIN + 2)); + stallcount = count; + mg_mgr_poll(&mgr, 0), s_driver_data.len = 0; + mg_mgr_poll(&mgr, 0), s_driver_data.len = 0; + ASSERT((stallcount == count)); + s_driver_data.tx_ready = false; + create_tcp_simpleseg(&e, &ipp, 1001, seq - TCP_TEST_WIN/2, TH_ACK, 0); // send ACK for half window + while (!received_response(&s_driver_data)) mg_mgr_poll(&mgr, 0); + ASSERT((stallcount < count)); + do { + while (!received_response(&s_driver_data)) mg_mgr_poll(&mgr, 0); + seq = (uint32_t)(mg_htonl(t->seq) + s_driver_data.len - (size_t)((char *)((uint32_t *)t + (t->off >> 4)) - s_driver_data.buf)); + } while (seq < (TCP_TEST_WIN + TCP_TEST_WIN/2 + 2)); + stallcount = count; + mg_mgr_poll(&mgr, 0), s_driver_data.len = 0; + mg_mgr_poll(&mgr, 0), s_driver_data.len = 0; + ASSERT((stallcount == count)); + s_driver_data.tx_ready = false; + create_tcp_simpleseg(&e, &ipp, 1001, seq, TH_ACK, 0); // send ACK + while (!received_response(&s_driver_data)) mg_mgr_poll(&mgr, 0); + ASSERT((stallcount < count)); + s_driver_data.len = 0; + mg_mgr_free(&mgr); +} + static void test_frag_recv_path(void) { struct mg_mgr mgr; struct eth e; @@ -1058,6 +1120,7 @@ static void test_tcp(bool ipv6) { if (!ipv6) { test_tcp_backlog(); test_tcp_retransmit(); + test_tcp_txwindow(); } }