Commit ba7517e4 authored by StevenWdV's avatar StevenWdV

Made sequence numbers cycle to support files over 65KB. Fixed bug that causes...

Made sequence numbers cycle to support files over 65KB. Fixed bug that causes SYN-ACK to not necessarily be ACKed. Fixed bug in enqueue_data.
parent a0b9e219
......@@ -13,6 +13,7 @@ from typing import *
header_format = "!HHHHBBHI"
header_size = 16
payload_size = 1000
max_seq = 0xffFF
class _Flags:
......@@ -117,7 +118,6 @@ class _RemoteInit:
# TODO more timeouts
# TODO fast retransmit
# TODO cycle sequence numbers
# TODO? dynamic windows size
class _Stream:
def __init__(self, binding: "Binding", local_port: int,
......@@ -144,7 +144,7 @@ class _Stream:
self.deb = f"{self.binding.local_udp_addr}|{self.local_port}->{self.remote_port}: "
self.first_send_syn_nr = random.randint(0, 0xff)
self.first_send_syn_nr = random.randint(0, max_seq)
if remote_init is None: # We initiate the connection
self.send_syn_nr = self.first_send_syn_nr # Next to be sent
......@@ -158,7 +158,7 @@ class _Stream:
self.remote_window_size = remote_init.window_size
self.recv_syn_nr = remote_init.syn_nr # Last received
self.recv_ack_nr = remote_init.syn_nr + 1 # First still to be received
self.recv_ack_nr = self.__next_seq(remote_init.syn_nr) # First still to be received
self.send_syn_nr = self.first_send_syn_nr # Next to be sent
self.send_ack_nr = self.send_syn_nr # First not ACKed by other
......@@ -166,14 +166,27 @@ class _Stream:
self.__send_to_be_acked(b"", self.send_syn_nr, _Flags((True, True, False)))
self.send_syn_nr += 1
self.send_syn_nr = self.__next_seq(self.send_syn_nr)
logging.debug(self.deb + "Setup stream")
@staticmethod
def __next_seq(nr: int) -> int:
return (nr + 1) % max_seq
@staticmethod
def __seq_between(nr: int, lower_bound: int, upper_bound: int) -> bool:
nr %= max_seq
lower_bound %= max_seq
upper_bound %= max_seq
if lower_bound < upper_bound:
return lower_bound <= nr <= upper_bound
else:
return nr >= lower_bound or nr <= upper_bound
def __send_to_be_acked(self, data: bytes, syn_nr: int, flags=_Flags((False, True, False))) -> None:
logging.debug(self.deb +
f"Sending #{syn_nr} {flags}{f' with ACK {self.recv_ack_nr}' if flags.ack else ''} "
f"with {len(data)} bytes")
self.send_syn_nr = max(syn_nr, self.send_syn_nr)
header = _Header(self.local_port, self.remote_port, syn_nr, self.recv_ack_nr if flags.ack else 0,
int(flags), self.local_window_size, len(data))
packet = _TimestampPacket(header, data, time.perf_counter())
......@@ -204,14 +217,15 @@ class _Stream:
if self.expect_syn_ack:
if flags.syn and flags.ack and not flags.fin:
if packet.header.ack_nr != self.send_ack_nr + 1:
if packet.header.ack_nr != self.__next_seq(self.send_ack_nr):
logging.warning(self.deb + "Wrong ACK nr in SYN-ACK")
continue
self.expect_syn_ack = False
self.recv_syn_nr = packet.header.syn_nr
self.recv_ack_nr = packet.header.syn_nr + 1
self.recv_ack_nr = self.__next_seq(packet.header.syn_nr)
self.first_recv_syn_nr = self.recv_syn_nr
self.remote_window_size = packet.header.window
send_ack = True
else:
logging.warning(self.deb + "SYN-ACK expected")
continue
......@@ -224,46 +238,43 @@ class _Stream:
# TODO simultaneous open if self.first_recv_syn_nr is None
continue
if packet.header.syn_nr < self.first_recv_syn_nr:
logging.warning(self.deb + "Too low SYN")
if not self.__seq_between(packet.header.syn_nr, self.recv_ack_nr - self.local_window_size,
self.recv_ack_nr + self.local_window_size - 1):
logging.debug(self.deb + "SYN nr outside of window (spurious retransmission?)")
continue
if packet.header.syn_nr > self.recv_syn_nr + self.local_window_size:
logging.warning(self.deb + "Too high SYN")
continue
if self.__seq_between(packet.header.syn_nr, self.recv_ack_nr + 1,
self.recv_ack_nr + self.local_window_size - 1):
self.recv_syn_nr = packet.header.syn_nr
# TODO if flags.fin:
if flags.ack:
if packet.header.ack_nr > self.send_syn_nr + 1:
logging.warning(self.deb + "Too high ACK")
continue
if packet.header.ack_nr < self.first_send_syn_nr:
logging.warning(self.deb + "Too low ACK")
continue
if packet.header.ack_nr > self.send_ack_nr:
if self.__seq_between(packet.header.ack_nr, self.send_ack_nr + 1, self.send_syn_nr):
self.send_ack_nr = packet.header.ack_nr
# Delete ACKed messages
self.sent_ack_buffer = [p for p in self.sent_ack_buffer if p.header.syn_nr >= packet.header.ack_nr]
self.recv_syn_nr = max(self.recv_syn_nr, packet.header.syn_nr)
self.sent_ack_buffer = [
p for p in self.sent_ack_buffer
if not self.__seq_between(p.header.syn_nr, self.send_ack_nr - self.remote_window_size,
self.send_ack_nr - 1)]
if packet.header.data_length > 0:
if packet.header.syn_nr == self.recv_ack_nr:
# Move run of received packets to receive buffer
self.receive_buffer_lock.acquire()
self.receive_buffer.append(packet.data)
self.recv_ack_nr += 1
self.recv_ack_nr = self.__next_seq(self.recv_ack_nr)
i = 0
while (i < len(self.receive_ack_buffer)
and self.receive_ack_buffer[i].header.syn_nr == self.recv_ack_nr):
self.receive_buffer.append(self.receive_ack_buffer.pop(i).data)
self.recv_ack_nr += 1
self.recv_ack_nr = self.__next_seq(self.recv_ack_nr)
i += 1
self.receive_bytes_available.notify()
self.receive_buffer_lock.release()
elif packet.header.syn_nr > self.recv_ack_nr:
elif self.__seq_between(packet.header.syn_nr, self.recv_ack_nr + 1,
self.recv_ack_nr + self.local_window_size - 1):
# There is a gap, store this packet
insert_index = bisect.bisect_left(self.receive_ack_buffer, packet)
......@@ -304,7 +315,7 @@ class _Stream:
while len(self.send_buffer) > 0 and len(self.sent_ack_buffer) < self.remote_window_size:
data = self.send_buffer.pop(0)
self.__send_to_be_acked(data, self.send_syn_nr)
self.send_syn_nr += 1
self.send_syn_nr = self.__next_seq(self.send_syn_nr)
self.send_buffer_lock.release()
def enqueue_data(self, data: bytes) -> None:
......@@ -312,7 +323,7 @@ class _Stream:
self.send_buffer_lock.acquire()
offset = 0
while offset < len(data):
self.send_buffer.append(data[offset: min(payload_size, len(data) - offset)])
self.send_buffer.append(data[offset: min(offset + payload_size, len(data))])
offset += payload_size
self.send_buffer_lock.release()
......@@ -537,7 +548,7 @@ class Binding:
data[header_size:], addr)
if packet.header.data_length > payload_size:
logging.warning(self.deb + "data_length too large")
packet.data = packet.data[:-(payload_size - packet.header.data_length)]
packet.data = packet.data[:packet.header.data_length]
if not packet.verify_checksum():
logging.warning(self.deb + "Invalid checksum")
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment