diff --git a/webrtc/modules/remote_bitrate_estimator/test/packet_sender.cc b/webrtc/modules/remote_bitrate_estimator/test/packet_sender.cc index 80a9710f7..3e0073a41 100644 --- a/webrtc/modules/remote_bitrate_estimator/test/packet_sender.cc +++ b/webrtc/modules/remote_bitrate_estimator/test/packet_sender.cc @@ -276,16 +276,25 @@ void TcpSender::RunFor(int64_t time_ms, Packets* in_out) { UpdateCongestionControl(fb); SendPackets(in_out); } + + for (auto it = in_flight_.begin(); it != in_flight_.end();) { + if (it->time_ms < now_ms_ - 1000) + in_flight_.erase(it++); + else + ++it; + } + SendPackets(in_out); now_ms_ += time_ms; } void TcpSender::SendPackets(Packets* in_out) { int cwnd = ceil(cwnd_); - int packets_to_send = std::max(cwnd - in_flight_, 0); + int packets_to_send = std::max(cwnd - static_cast(in_flight_.size()), 0); if (packets_to_send > 0) { Packets generated = GeneratePackets(packets_to_send); - in_flight_ += static_cast(generated.size()); + for (Packet* packet : generated) + in_flight_.insert(InFlight(*static_cast(packet))); in_out->merge(generated, DereferencingComparator); } } @@ -296,10 +305,12 @@ void TcpSender::UpdateCongestionControl(const FeedbackPacket* fb) { ack_received_ = true; uint16_t expected = tcp_fb->acked_packets().back() - last_acked_seq_num_; - in_flight_ -= expected; uint16_t missing = expected - static_cast(tcp_fb->acked_packets().size()); + for (uint16_t ack_seq_num : tcp_fb->acked_packets()) + in_flight_.erase(InFlight(ack_seq_num, now_ms_)); + if (missing > 0) { cwnd_ /= 2.0f; in_slow_start_ = false; diff --git a/webrtc/modules/remote_bitrate_estimator/test/packet_sender.h b/webrtc/modules/remote_bitrate_estimator/test/packet_sender.h index 4bdd5392c..0c579ed09 100644 --- a/webrtc/modules/remote_bitrate_estimator/test/packet_sender.h +++ b/webrtc/modules/remote_bitrate_estimator/test/packet_sender.h @@ -109,7 +109,6 @@ class TcpSender : public PacketSender { now_ms_(0), in_slow_start_(false), cwnd_(10), - in_flight_(0), ack_received_(false), last_acked_seq_num_(0), next_sequence_number_(0) {} @@ -120,6 +119,25 @@ class TcpSender : public PacketSender { int GetFeedbackIntervalMs() const override { return 10; } private: + struct InFlight { + public: + InFlight(const MediaPacket& packet) + : sequence_number(packet.header().sequenceNumber), + time_ms(packet.send_time_us() / 1000) {} + + InFlight(uint16_t seq_num, int64_t now_ms) + : sequence_number(seq_num), time_ms(now_ms) {} + + bool operator<(const InFlight& rhs) const { + return sequence_number < rhs.sequence_number; + } + + uint16_t sequence_number; // Sequence number of a packet in flight, or a + // packet which has just been acked. + int64_t time_ms; // Time of when the packet left the sender, or when the + // ack was received. + }; + void SendPackets(Packets* in_out); void UpdateCongestionControl(const FeedbackPacket* fb); Packets GeneratePackets(size_t num_packets); @@ -127,7 +145,7 @@ class TcpSender : public PacketSender { int64_t now_ms_; bool in_slow_start_; float cwnd_; - int in_flight_; + std::set in_flight_; bool ack_received_; uint16_t last_acked_seq_num_; uint16_t next_sequence_number_;