Commit 1d32f477 authored by StevenWdV's avatar StevenWdV
Browse files

First working prototype

parent ac7b2e9b
#!/usr/local/bin/python3 #!/bin/python3
import argparse import argparse
import logging
import socket import socket
import struct import struct
import sys
import btcp import btcp
logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("-w", "--window", help="Define bTCP window size", type=int, default=100) parser.add_argument("-w", "--window", help="Define bTCP window size", type=int, default=100)
parser.add_argument("-t", "--timeout", help="Define bTCP timeout in milliseconds", type=int, default=100) parser.add_argument("-t", "--timeout", help="Define bTCP timeout in milliseconds", type=int, default=100)
parser.add_argument("-i", "--input", help="File to send", default="tmp.file") parser.add_argument("-i", "--input", help="File to send", default="tmp.file")
args = parser.parse_args() args = parser.parse_args()
binding = btcp.Binding(socket.AF_INET, ("", 9002), args.window, args.timeout) binding = btcp.Binding(socket.AF_INET, None, args.window, args.timeout / 1000.)
connection = binding.connect_client(0, 0, ("", 9001)) connection = binding.connect_client(0, 0, ("", 9001))
file = open(args.input, "rb") file = open(args.input, "r+b")
data = file.read() data = file.read()
file.close() file.close()
connection.send(struct.pack("!Q", len(data))) connection.send(struct.pack("!Q", len(data)))
connection.send(data) connection.send(data)
input("press enter to stop\n")
binding.close() binding.close()
#!/usr/local/bin/python3 #!/bin/python3
import argparse import argparse
import logging
import socket import socket
import struct import struct
import sys
import btcp import btcp
logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
# Handle arguments # Handle arguments
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("-w", "--window", help="Define bTCP window size", type=int, default=100) parser.add_argument("-w", "--window", help="Define bTCP window size", type=int, default=100)
...@@ -12,7 +16,7 @@ parser.add_argument("-t", "--timeout", help="Define bTCP timeout in milliseconds ...@@ -12,7 +16,7 @@ parser.add_argument("-t", "--timeout", help="Define bTCP timeout in milliseconds
parser.add_argument("-o", "--output", help="Where to store file", default="tmp.file") parser.add_argument("-o", "--output", help="Where to store file", default="tmp.file")
args = parser.parse_args() args = parser.parse_args()
binding = btcp.Binding(socket.AF_INET, ("", 9001), args.window, args.timeout) binding = btcp.Binding(socket.AF_INET, ("", 9001), args.window, args.timeout / 1000.)
server = binding.bind_server(0) server = binding.bind_server(0)
server.start_listen(1) server.start_listen(1)
...@@ -20,11 +24,9 @@ connection = server.accept() ...@@ -20,11 +24,9 @@ connection = server.accept()
file_size: int = struct.unpack("!Q", connection.receive(8))[0] file_size: int = struct.unpack("!Q", connection.receive(8))[0]
file = open(args.output, "wb") file = open(args.output, "w+b")
while file_size > 0: data = connection.receive(file_size)
data = connection.receive_all() file.write(data)
file.write(data)
file_size -= len(data)
file.close() file.close()
binding.close() binding.close()
import binascii import binascii
import bisect import bisect
import copy import copy
import logging
import random import random
import select import select
import socket import socket
...@@ -26,9 +27,19 @@ class _Flags: ...@@ -26,9 +27,19 @@ class _Flags:
else: else:
self.syn, self.ack, self.fin = flags self.syn, self.ack, self.fin = flags
def to_int(self) -> int: def __int__(self) -> int:
return self.syn << 2 | self.ack << 1 | self.fin return self.syn << 2 | self.ack << 1 | self.fin
def __str__(self) -> str:
strs: List[str] = []
if self.syn:
strs.append("SYN")
if self.ack:
strs.append("ACK")
if self.fin:
strs.append("FIN")
return "-".join(strs)
class _Header: class _Header:
def __init__(self, src_port: int, dest_port: int, syn_nr: int, ack_nr: int, flags: int, def __init__(self, src_port: int, dest_port: int, syn_nr: int, ack_nr: int, flags: int,
...@@ -98,6 +109,7 @@ class _AddrPacket(_Packet): ...@@ -98,6 +109,7 @@ class _AddrPacket(_Packet):
self.remote_udp_addr = remote_udp_addr self.remote_udp_addr = remote_udp_addr
# Info for if the connection was not initiated from our side
class _RemoteInit: class _RemoteInit:
def __init__(self, syn_nr: int, window_size: int): def __init__(self, syn_nr: int, window_size: int):
self.syn_nr = syn_nr self.syn_nr = syn_nr
...@@ -107,13 +119,12 @@ class _RemoteInit: ...@@ -107,13 +119,12 @@ class _RemoteInit:
class _Stream: class _Stream:
def __init__(self, binding: "Binding", local_port: int, def __init__(self, binding: "Binding", local_port: int,
remote_port: int, remote_udp_addr: Any, remote_port: int, remote_udp_addr: Any,
local_window_size: int, timeout: int, local_window_size: int, timeout: float,
remote_init: Optional[_RemoteInit] = None): remote_init: Optional[_RemoteInit] = None):
self.binding = binding self.binding = binding
self.local_port = local_port self.local_port = local_port
self.remote_port = remote_port self.remote_port = remote_port
self.local_window_size = local_window_size self.local_window_size = local_window_size
self.remote_window_size = remote_init.window_size
self.remote_udp_addr = remote_udp_addr self.remote_udp_addr = remote_udp_addr
self.timeout = timeout self.timeout = timeout
self.expect_syn_ack = remote_init is None self.expect_syn_ack = remote_init is None
...@@ -128,35 +139,48 @@ class _Stream: ...@@ -128,35 +139,48 @@ class _Stream:
self.receive_buffer_lock = threading.Lock() self.receive_buffer_lock = threading.Lock()
self.receive_bytes_available = threading.Condition(self.receive_buffer_lock) self.receive_bytes_available = threading.Condition(self.receive_buffer_lock)
if remote_init is None: self.deb = f"{self.binding.local_udp_addr}|{self.local_port}->{self.remote_port}: "
self.send_syn_nr = random.randint(0, 0xffFF) # Next to be sent
self.first_send_syn_nr = random.randint(0, 0xffFF)
if remote_init is None: # We initiate the connection
self.send_syn_nr = self.first_send_syn_nr # Next to be sent
self.send_ack_nr = self.send_syn_nr # First not ACKed by other self.send_ack_nr = self.send_syn_nr # First not ACKed by other
self.__send_to_be_acked(b"", self.send_syn_nr, _Flags((True, False, False))) self.__send_to_be_acked(b"", self.send_syn_nr, _Flags((True, False, False)))
else: else:
self.remote_window_size = remote_init.window_size
self.recv_syn_nr = remote_init.syn_nr # Last received self.recv_syn_nr = remote_init.syn_nr # Last received
self.recv_ack_nr = remote_init.syn_nr + 1 # First still to be received self.recv_ack_nr = remote_init.syn_nr + 1 # First still to be received
self.send_syn_nr = random.randint(0, 0xffFF) # Next to be sent self.send_syn_nr = self.first_send_syn_nr # Next to be sent
self.send_ack_nr = self.send_syn_nr # First not ACKed by other self.send_ack_nr = self.send_syn_nr # First not ACKed by other
self.first_recv_syn_nr = self.recv_syn_nr
self.__send_to_be_acked(b"", self.send_syn_nr, _Flags((True, True, False))) self.__send_to_be_acked(b"", self.send_syn_nr, _Flags((True, True, False)))
self.send_syn_nr += 1 self.send_syn_nr += 1
logging.debug(self.deb + "Setup stream")
def __send_to_be_acked(self, data: bytes, syn_nr: int, flags=_Flags((False, True, False))) -> None: def __send_to_be_acked(self, data: bytes, syn_nr: int, flags=_Flags((False, True, False))) -> None:
logging.debug(self.deb +
f"Sending #{syn_nr} {flags}{f' with ACK {self.recv_ack_nr}' if flags.ack else ''} "
f"with {len(data)} bytes")
self.send_syn_nr = max(syn_nr, self.send_syn_nr) self.send_syn_nr = max(syn_nr, self.send_syn_nr)
header = _Header(self.local_port, self.remote_port, syn_nr, self.recv_ack_nr if flags.ack else 0, header = _Header(self.local_port, self.remote_port, syn_nr, self.recv_ack_nr if flags.ack else 0,
flags.to_int(), self.local_window_size, len(data)) int(flags), self.local_window_size, len(data))
packet = _TimestampPacket(header, data, time.clock()) packet = _TimestampPacket(header, data, time.perf_counter())
packet.set_checksum() packet.set_checksum()
self.sent_ack_buffer.append(packet) self.sent_ack_buffer.append(packet)
# noinspection PyProtectedMember # noinspection PyProtectedMember
self.binding._send(packet, self.remote_udp_addr) self.binding._send(packet, self.remote_udp_addr)
def __send_ack(self) -> None: def __send_ack(self) -> None:
logging.debug(self.deb + f"Sending ACK {self.recv_ack_nr}")
ack_header = _Header(self.local_port, self.remote_port, self.send_syn_nr, self.recv_ack_nr, ack_header = _Header(self.local_port, self.remote_port, self.send_syn_nr, self.recv_ack_nr,
_Flags((False, True, False)).to_int(), self.local_window_size, 0) int(_Flags((False, True, False))), self.local_window_size, 0)
ack_packet = _Packet(ack_header, b"") ack_packet = _Packet(ack_header, b"")
ack_packet.set_checksum() ack_packet.set_checksum()
...@@ -164,41 +188,49 @@ class _Stream: ...@@ -164,41 +188,49 @@ class _Stream:
self.binding._send(ack_packet, self.remote_udp_addr) self.binding._send(ack_packet, self.remote_udp_addr)
def pass_received_msgs(self, packets: List[_AddrPacket]) -> None: def pass_received_msgs(self, packets: List[_AddrPacket]) -> None:
logging.debug(self.deb + f"Received {len(packets)} messages")
send_ack = False send_ack = False
for packet in packets: for packet in packets:
if packet.remote_udp_addr != self.remote_udp_addr:
print("Wrong remote UDP address")
continue
flags = packet.header.flags_obj() flags = packet.header.flags_obj()
logging.debug(self.deb + f"Received #{packet.header.syn_nr} {flags}"
f"{f' with ACK {packet.header.ack_nr}' if flags.ack else ''} "
f"with {len(packet.data)} bytes")
if self.expect_syn_ack: if self.expect_syn_ack:
if flags.syn and flags.ack and not flags.fin: if flags.syn and flags.ack and not flags.fin:
if packet.header.ack_nr != self.send_ack_nr: if packet.header.ack_nr != self.send_ack_nr + 1:
print("Wrong ACK nr in SYN-ACK") logging.warning(self.deb + "Wrong ACK nr in SYN-ACK")
continue continue
self.expect_syn_ack = False self.expect_syn_ack = False
self.recv_syn_nr = packet.header.syn_nr self.recv_syn_nr = packet.header.syn_nr
self.recv_ack_nr = packet.header.syn_nr + 1 self.recv_ack_nr = packet.header.syn_nr + 1
self.first_recv_syn_nr = self.recv_syn_nr
self.remote_window_size = packet.header.window self.remote_window_size = packet.header.window
else: else:
print("SYN-ACK expected") logging.warning(self.deb + "SYN-ACK expected")
continue continue
elif flags.syn: elif flags.syn:
print("Unexpected SYN") if packet.header.syn_nr == self.first_recv_syn_nr:
# TODO simultaneous open? logging.debug(self.deb + "Spurious SYN")
continue
else:
logging.warning(self.deb + "Unexpected SYN")
# TODO simultaneous open
continue continue
if packet.header.syn_nr > self.recv_syn_nr + self.local_window_size: if packet.header.syn_nr > self.recv_syn_nr + self.local_window_size:
print("Too high SYN") logging.warning(self.deb + "Too high SYN")
continue continue
# TODO if flags.fin: # TODO if flags.fin:
if flags.ack: if flags.ack:
if packet.header.ack_nr > self.send_syn_nr + 1: if packet.header.ack_nr > self.send_syn_nr + 1:
print("Too high ACK") logging.warning(self.deb + "Too high ACK")
continue
if packet.header.ack_nr < self.first_send_syn_nr:
logging.warning(self.deb + "Too low ACK")
continue continue
if packet.header.ack_nr > self.send_ack_nr: if packet.header.ack_nr > self.send_ack_nr:
self.send_ack_nr = packet.header.ack_nr self.send_ack_nr = packet.header.ack_nr
...@@ -225,9 +257,12 @@ class _Stream: ...@@ -225,9 +257,12 @@ class _Stream:
elif packet.header.syn_nr > self.recv_ack_nr: elif packet.header.syn_nr > self.recv_ack_nr:
# There is a gap, store this packet # There is a gap, store this packet
insert_index = bisect.bisect_left(self.receive_ack_buffer, packet) insert_index = bisect.bisect_left(self.receive_ack_buffer, packet)
# Check for duplicate packet
if (insert_index + 1 > len(self.receive_ack_buffer) if (insert_index + 1 < len(self.receive_ack_buffer)
or self.receive_ack_buffer[insert_index + 1].header.syn_nr != packet.header.syn_nr): and self.receive_ack_buffer[insert_index + 1].header.syn_nr == packet.header.syn_nr):
assert self.receive_ack_buffer[insert_index + 1].data == packet.data
logging.debug(self.deb + "Spurious retransmission")
else:
self.receive_ack_buffer.insert(insert_index, packet) self.receive_ack_buffer.insert(insert_index, packet)
send_ack = True send_ack = True
...@@ -237,27 +272,34 @@ class _Stream: ...@@ -237,27 +272,34 @@ class _Stream:
self.send_buffer_lock.acquire() self.send_buffer_lock.acquire()
data_to_send = len(self.send_buffer) > 0 data_to_send = len(self.send_buffer) > 0
self.send_buffer_lock.release() self.send_buffer_lock.release()
if data_to_send: if not data_to_send:
self.__send_ack() self.__send_ack()
def poll_sender(self) -> None: def poll_sender(self) -> None:
lost = [p for p in self.sent_ack_buffer if time.perf_counter() - p.timestamp >= self.timeout]
self.sent_ack_buffer = [p for p in self.sent_ack_buffer if time.perf_counter() - p.timestamp < self.timeout]
if len(lost) > 0:
logging.debug(self.deb + f"{len(lost)} lost messages")
for p in lost:
p.timestamp = time.perf_counter()
self.__send_to_be_acked(p.data, p.header.syn_nr, p.header.flags_obj())
if self.expect_syn_ack: if self.expect_syn_ack:
return return
lost = [p for p in self.sent_ack_buffer if time.clock() - p.timestamp >= self.timeout] if len(self.send_buffer) > 0:
for p in lost: logging.debug(self.deb + f"{len(self.send_buffer)} unsent messages")
p.timestamp = time.clock()
self.__send_to_be_acked(p.data, p.header.syn_nr, p.header.flags_obj())
if len(self.sent_ack_buffer) < self.remote_window_size: if len(self.sent_ack_buffer) < self.remote_window_size:
self.send_buffer_lock.acquire() self.send_buffer_lock.acquire()
while len(self.sent_ack_buffer) < self.remote_window_size: while len(self.send_buffer) > 0 and len(self.sent_ack_buffer) < self.remote_window_size:
data = self.send_buffer.pop(0) data = self.send_buffer.pop(0)
self.__send_to_be_acked(data + bytes(payload_size - len(data)), self.send_syn_nr) self.__send_to_be_acked(data, self.send_syn_nr)
self.send_syn_nr += 1 self.send_syn_nr += 1
self.send_buffer_lock.release() self.send_buffer_lock.release()
def enqueue_data(self, data: bytes) -> None: def enqueue_data(self, data: bytes) -> None:
logging.debug(self.deb + f"Enqueuing {len(data)} bytes")
self.send_buffer_lock.acquire() self.send_buffer_lock.acquire()
offset = 0 offset = 0
while offset < len(data): while offset < len(data):
...@@ -268,11 +310,12 @@ class _Stream: ...@@ -268,11 +310,12 @@ class _Stream:
def __at_least_received(self, count: int) -> bool: def __at_least_received(self, count: int) -> bool:
i = 0 i = 0
while i < len(self.receive_buffer) and count > 0: while i < len(self.receive_buffer) and count > 0:
count += len(self.receive_buffer[i]) count -= len(self.receive_buffer[i])
i += 1 i += 1
return count <= 0 return count <= 0
def dequeue_data(self, count: int) -> bytes: def dequeue_data(self, count: int) -> bytes:
logging.debug(self.deb + f"Dequeuing {count} bytes...")
self.receive_buffer_lock.acquire() self.receive_buffer_lock.acquire()
self.receive_bytes_available.wait_for(lambda: self.__at_least_received(count)) self.receive_bytes_available.wait_for(lambda: self.__at_least_received(count))
...@@ -281,10 +324,12 @@ class _Stream: ...@@ -281,10 +324,12 @@ class _Stream:
data += self.receive_buffer.pop(0) data += self.receive_buffer.pop(0)
if len(data) < count: if len(data) < count:
data += self.receive_buffer[0][:len(data) - count] bytes_short = count - len(data)
self.receive_buffer[0] = self.receive_buffer[0][len(data) - count:] data += self.receive_buffer[0][:bytes_short]
self.receive_buffer[0] = self.receive_buffer[0][bytes_short:]
self.receive_buffer_lock.release() self.receive_buffer_lock.release()
logging.debug(self.deb + "Dequeued")
return data return data
def dequeue_is_available(self, count: int) -> bool: def dequeue_is_available(self, count: int) -> bool:
...@@ -294,14 +339,17 @@ class _Stream: ...@@ -294,14 +339,17 @@ class _Stream:
return available return available
def dequeue_all_data(self) -> bytes: def dequeue_all_data(self) -> bytes:
logging.debug(self.deb + "Dequeuing all data...")
self.receive_buffer_lock.acquire() self.receive_buffer_lock.acquire()
data = b"" data = b""
while len(self.receive_buffer) > 0: while len(self.receive_buffer) > 0:
data += self.receive_buffer.pop(0) data += self.receive_buffer.pop(0)
self.receive_buffer_lock.release() self.receive_buffer_lock.release()
logging.debug(self.deb + f"Dequeued {len(data)} bytes")
return data return data
def close(self) -> None: def close(self) -> None:
logging.debug(self.deb + "Close")
# TODO FIN # TODO FIN
pass pass
...@@ -310,6 +358,11 @@ class _Connection: ...@@ -310,6 +358,11 @@ class _Connection:
def __init__(self, parent: Union["_Server", "Binding"], stream: _Stream): def __init__(self, parent: Union["_Server", "Binding"], stream: _Stream):
self.parent = parent self.parent = parent
self.stream = stream self.stream = stream
self.deb = f"{self.__parent_str()}|{self.stream.local_port}->{self.stream.remote_port}: "
logging.debug(self.deb + "Setup connection")
def __parent_str(self) -> str:
return self.parent.local_udp_addr if type(self.parent) is Binding else self.parent.binding.local_udp_addr
def _pass_received_msgs(self, packets: List[_AddrPacket]) -> None: def _pass_received_msgs(self, packets: List[_AddrPacket]) -> None:
self.stream.pass_received_msgs(packets) self.stream.pass_received_msgs(packets)
...@@ -330,6 +383,7 @@ class _Connection: ...@@ -330,6 +383,7 @@ class _Connection:
return self.stream.dequeue_all_data() return self.stream.dequeue_all_data()
def close(self) -> None: def close(self) -> None:
logging.debug(self.deb + "Close")
if type(self.parent) is _Server: if type(self.parent) is _Server:
# noinspection PyProtectedMember # noinspection PyProtectedMember
self.parent._remove_stream(self.stream.remote_port) self.parent._remove_stream(self.stream.remote_port)
...@@ -337,20 +391,27 @@ class _Connection: ...@@ -337,20 +391,27 @@ class _Connection:
# noinspection PyProtectedMember # noinspection PyProtectedMember
self.parent._remove_socket(self.stream.remote_port) self.parent._remove_socket(self.stream.remote_port)
self.stream.close() self.stream.close()
logging.debug(self.deb + "Closed")
class _Server: class _Server:
def __init__(self, binding: "Binding", local_port: int): def __init__(self, binding: "Binding", local_port: int):
self.binding = binding self.binding = binding
self.local_port = local_port self.local_port = local_port
self.streams: Dict[int, _Stream] = {} self.streams: Dict[int, _Stream] = {} # Remote port -> _Stream
self.streams_lock = threading.Lock() self.streams_lock = threading.Lock()
self.backlog = 0 self.backlog = 0
self.backlog_connections: List[_Connection] = [] self.backlog_connections: List[_Connection] = []
self.backlog_lock = threading.Lock() self.backlog_lock = threading.Lock()
self.backlog_available = threading.Condition(self.backlog_lock)
self.deb = f"{self.binding.local_udp_addr}|{self.local_port}: "
logging.debug(self.deb + "Setup server")
def _pass_received_msgs(self, packets: List[_AddrPacket]) -> None: def _pass_received_msgs(self, packets: List[_AddrPacket]) -> None:
logging.debug(self.deb + f"Received {len(packets)} messages")
packet_batches: Dict[int, List[_AddrPacket]] = {} packet_batches: Dict[int, List[_AddrPacket]] = {}
orphan_packets: List[_AddrPacket] = [] orphan_packets: List[_AddrPacket] = []
...@@ -358,8 +419,7 @@ class _Server: ...@@ -358,8 +419,7 @@ class _Server:
for packet in packets: for packet in packets:
if packet.header.src_port in self.streams: if packet.header.src_port in self.streams:
batch = packet_batches[packet.header.src_port] = [] packet_batches.setdefault(packet.header.src_port, []).append(packet)
batch.append(packet)
else: else:
orphan_packets.append(packet) orphan_packets.append(packet)
...@@ -379,12 +439,13 @@ class _Server: ...@@ -379,12 +439,13 @@ class _Server:
_RemoteInit(packet.header.syn_nr, packet.header.window)) _RemoteInit(packet.header.syn_nr, packet.header.window))
self.backlog_connections.append(_Connection(self, stream)) self.backlog_connections.append(_Connection(self, stream))
self.backlog -= 1 self.backlog -= 1
self.backlog_available.notify()
self.backlog_lock.release() self.backlog_lock.release()
else: else:
self.backlog_lock.release() self.backlog_lock.release()
print("Backlog full") logging.warning(self.deb + "Backlog full")
else: else:
print("Unknown connection") logging.warning(self.deb + "Unknown connection")
self.streams_lock.release() self.streams_lock.release()
...@@ -395,47 +456,58 @@ class _Server: ...@@ -395,47 +456,58 @@ class _Server:
self.streams_lock.release() self.streams_lock.release()
def _remove_stream(self, remote_port: int) -> None: def _remove_stream(self, remote_port: int) -> None:
logging.debug(self.deb + "Remove stream")
self.streams_lock.acquire() self.streams_lock.acquire()