Commit ba7517e4 authored by StevenWdV's avatar StevenWdV
Browse files

Made sequence numbers cycle to support files over 65KB. Fixed bug that causes...

Made sequence numbers cycle to support files over 65KB. Fixed bug that causes SYN-ACK to not necessarily be ACKed. Fixed bug in enqueue_data.
parent a0b9e219
...@@ -13,6 +13,7 @@ from typing import * ...@@ -13,6 +13,7 @@ from typing import *
header_format = "!HHHHBBHI" header_format = "!HHHHBBHI"
header_size = 16 header_size = 16
payload_size = 1000 payload_size = 1000
max_seq = 0xffFF
class _Flags: class _Flags:
...@@ -117,7 +118,6 @@ class _RemoteInit: ...@@ -117,7 +118,6 @@ class _RemoteInit:
# TODO more timeouts # TODO more timeouts
# TODO fast retransmit # TODO fast retransmit
# TODO cycle sequence numbers
# TODO? dynamic windows size # TODO? dynamic windows size
class _Stream: class _Stream:
def __init__(self, binding: "Binding", local_port: int, def __init__(self, binding: "Binding", local_port: int,
...@@ -144,7 +144,7 @@ class _Stream: ...@@ -144,7 +144,7 @@ class _Stream:
self.deb = f"{self.binding.local_udp_addr}|{self.local_port}->{self.remote_port}: " self.deb = f"{self.binding.local_udp_addr}|{self.local_port}->{self.remote_port}: "
self.first_send_syn_nr = random.randint(0, 0xff) self.first_send_syn_nr = random.randint(0, max_seq)
if remote_init is None: # We initiate the connection if remote_init is None: # We initiate the connection
self.send_syn_nr = self.first_send_syn_nr # Next to be sent self.send_syn_nr = self.first_send_syn_nr # Next to be sent
...@@ -158,7 +158,7 @@ class _Stream: ...@@ -158,7 +158,7 @@ class _Stream:
self.remote_window_size = remote_init.window_size self.remote_window_size = remote_init.window_size
self.recv_syn_nr = remote_init.syn_nr # Last received self.recv_syn_nr = remote_init.syn_nr # Last received
self.recv_ack_nr = remote_init.syn_nr + 1 # First still to be received self.recv_ack_nr = self.__next_seq(remote_init.syn_nr) # First still to be received
self.send_syn_nr = self.first_send_syn_nr # Next to be sent self.send_syn_nr = self.first_send_syn_nr # Next to be sent
self.send_ack_nr = self.send_syn_nr # First not ACKed by other self.send_ack_nr = self.send_syn_nr # First not ACKed by other
...@@ -166,14 +166,27 @@ class _Stream: ...@@ -166,14 +166,27 @@ class _Stream:
self.__send_to_be_acked(b"", self.send_syn_nr, _Flags((True, True, False))) self.__send_to_be_acked(b"", self.send_syn_nr, _Flags((True, True, False)))
self.send_syn_nr += 1 self.send_syn_nr = self.__next_seq(self.send_syn_nr)
logging.debug(self.deb + "Setup stream") logging.debug(self.deb + "Setup stream")
@staticmethod
def __next_seq(nr: int) -> int:
return (nr + 1) % max_seq
@staticmethod
def __seq_between(nr: int, lower_bound: int, upper_bound: int) -> bool:
nr %= max_seq
lower_bound %= max_seq
upper_bound %= max_seq
if lower_bound < upper_bound:
return lower_bound <= nr <= upper_bound
else:
return nr >= lower_bound or nr <= upper_bound
def __send_to_be_acked(self, data: bytes, syn_nr: int, flags=_Flags((False, True, False))) -> None: def __send_to_be_acked(self, data: bytes, syn_nr: int, flags=_Flags((False, True, False))) -> None:
logging.debug(self.deb + logging.debug(self.deb +
f"Sending #{syn_nr} {flags}{f' with ACK {self.recv_ack_nr}' if flags.ack else ''} " f"Sending #{syn_nr} {flags}{f' with ACK {self.recv_ack_nr}' if flags.ack else ''} "
f"with {len(data)} bytes") f"with {len(data)} bytes")
self.send_syn_nr = max(syn_nr, self.send_syn_nr)
header = _Header(self.local_port, self.remote_port, syn_nr, self.recv_ack_nr if flags.ack else 0, header = _Header(self.local_port, self.remote_port, syn_nr, self.recv_ack_nr if flags.ack else 0,
int(flags), self.local_window_size, len(data)) int(flags), self.local_window_size, len(data))
packet = _TimestampPacket(header, data, time.perf_counter()) packet = _TimestampPacket(header, data, time.perf_counter())
...@@ -204,14 +217,15 @@ class _Stream: ...@@ -204,14 +217,15 @@ class _Stream:
if self.expect_syn_ack: if self.expect_syn_ack:
if flags.syn and flags.ack and not flags.fin: if flags.syn and flags.ack and not flags.fin:
if packet.header.ack_nr != self.send_ack_nr + 1: if packet.header.ack_nr != self.__next_seq(self.send_ack_nr):
logging.warning(self.deb + "Wrong ACK nr in SYN-ACK") logging.warning(self.deb + "Wrong ACK nr in SYN-ACK")
continue continue
self.expect_syn_ack = False self.expect_syn_ack = False
self.recv_syn_nr = packet.header.syn_nr self.recv_syn_nr = packet.header.syn_nr
self.recv_ack_nr = packet.header.syn_nr + 1 self.recv_ack_nr = self.__next_seq(packet.header.syn_nr)
self.first_recv_syn_nr = self.recv_syn_nr self.first_recv_syn_nr = self.recv_syn_nr
self.remote_window_size = packet.header.window self.remote_window_size = packet.header.window
send_ack = True
else: else:
logging.warning(self.deb + "SYN-ACK expected") logging.warning(self.deb + "SYN-ACK expected")
continue continue
...@@ -224,46 +238,43 @@ class _Stream: ...@@ -224,46 +238,43 @@ class _Stream:
# TODO simultaneous open if self.first_recv_syn_nr is None # TODO simultaneous open if self.first_recv_syn_nr is None
continue continue
if packet.header.syn_nr < self.first_recv_syn_nr: if not self.__seq_between(packet.header.syn_nr, self.recv_ack_nr - self.local_window_size,
logging.warning(self.deb + "Too low SYN") self.recv_ack_nr + self.local_window_size - 1):
logging.debug(self.deb + "SYN nr outside of window (spurious retransmission?)")
continue continue
if packet.header.syn_nr > self.recv_syn_nr + self.local_window_size: if self.__seq_between(packet.header.syn_nr, self.recv_ack_nr + 1,
logging.warning(self.deb + "Too high SYN") self.recv_ack_nr + self.local_window_size - 1):
continue self.recv_syn_nr = packet.header.syn_nr
# TODO if flags.fin: # TODO if flags.fin:
if flags.ack: if flags.ack:
if packet.header.ack_nr > self.send_syn_nr + 1: if self.__seq_between(packet.header.ack_nr, self.send_ack_nr + 1, self.send_syn_nr):
logging.warning(self.deb + "Too high ACK")
continue
if packet.header.ack_nr < self.first_send_syn_nr:
logging.warning(self.deb + "Too low ACK")
continue
if packet.header.ack_nr > self.send_ack_nr:
self.send_ack_nr = packet.header.ack_nr self.send_ack_nr = packet.header.ack_nr
# Delete ACKed messages # Delete ACKed messages
self.sent_ack_buffer = [p for p in self.sent_ack_buffer if p.header.syn_nr >= packet.header.ack_nr] self.sent_ack_buffer = [
p for p in self.sent_ack_buffer
self.recv_syn_nr = max(self.recv_syn_nr, packet.header.syn_nr) if not self.__seq_between(p.header.syn_nr, self.send_ack_nr - self.remote_window_size,
self.send_ack_nr - 1)]
if packet.header.data_length > 0: if packet.header.data_length > 0:
if packet.header.syn_nr == self.recv_ack_nr: if packet.header.syn_nr == self.recv_ack_nr:
# Move run of received packets to receive buffer # Move run of received packets to receive buffer
self.receive_buffer_lock.acquire() self.receive_buffer_lock.acquire()
self.receive_buffer.append(packet.data) self.receive_buffer.append(packet.data)
self.recv_ack_nr += 1 self.recv_ack_nr = self.__next_seq(self.recv_ack_nr)
i = 0 i = 0
while (i < len(self.receive_ack_buffer) while (i < len(self.receive_ack_buffer)
and self.receive_ack_buffer[i].header.syn_nr == self.recv_ack_nr): and self.receive_ack_buffer[i].header.syn_nr == self.recv_ack_nr):
self.receive_buffer.append(self.receive_ack_buffer.pop(i).data) self.receive_buffer.append(self.receive_ack_buffer.pop(i).data)
self.recv_ack_nr += 1 self.recv_ack_nr = self.__next_seq(self.recv_ack_nr)
i += 1 i += 1
self.receive_bytes_available.notify() self.receive_bytes_available.notify()
self.receive_buffer_lock.release() self.receive_buffer_lock.release()
elif packet.header.syn_nr > self.recv_ack_nr: elif self.__seq_between(packet.header.syn_nr, self.recv_ack_nr + 1,
self.recv_ack_nr + self.local_window_size - 1):
# There is a gap, store this packet # There is a gap, store this packet
insert_index = bisect.bisect_left(self.receive_ack_buffer, packet) insert_index = bisect.bisect_left(self.receive_ack_buffer, packet)
...@@ -304,7 +315,7 @@ class _Stream: ...@@ -304,7 +315,7 @@ class _Stream:
while len(self.send_buffer) > 0 and len(self.sent_ack_buffer) < self.remote_window_size: while len(self.send_buffer) > 0 and len(self.sent_ack_buffer) < self.remote_window_size:
data = self.send_buffer.pop(0) data = self.send_buffer.pop(0)
self.__send_to_be_acked(data, self.send_syn_nr) self.__send_to_be_acked(data, self.send_syn_nr)
self.send_syn_nr += 1 self.send_syn_nr = self.__next_seq(self.send_syn_nr)
self.send_buffer_lock.release() self.send_buffer_lock.release()
def enqueue_data(self, data: bytes) -> None: def enqueue_data(self, data: bytes) -> None:
...@@ -312,7 +323,7 @@ class _Stream: ...@@ -312,7 +323,7 @@ class _Stream:
self.send_buffer_lock.acquire() self.send_buffer_lock.acquire()
offset = 0 offset = 0
while offset < len(data): while offset < len(data):
self.send_buffer.append(data[offset: min(payload_size, len(data) - offset)]) self.send_buffer.append(data[offset: min(offset + payload_size, len(data))])
offset += payload_size offset += payload_size
self.send_buffer_lock.release() self.send_buffer_lock.release()
...@@ -537,7 +548,7 @@ class Binding: ...@@ -537,7 +548,7 @@ class Binding:
data[header_size:], addr) data[header_size:], addr)
if packet.header.data_length > payload_size: if packet.header.data_length > payload_size:
logging.warning(self.deb + "data_length too large") logging.warning(self.deb + "data_length too large")
packet.data = packet.data[:-(payload_size - packet.header.data_length)] packet.data = packet.data[:packet.header.data_length]
if not packet.verify_checksum(): if not packet.verify_checksum():
logging.warning(self.deb + "Invalid checksum") logging.warning(self.deb + "Invalid checksum")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment