Commit 1d32f477 authored by StevenWdV's avatar StevenWdV

First working prototype

parent ac7b2e9b
#!/usr/local/bin/python3
#!/bin/python3
import argparse
import logging
import socket
import struct
import sys
import btcp
logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
parser = argparse.ArgumentParser()
parser.add_argument("-w", "--window", help="Define bTCP window size", type=int, default=100)
parser.add_argument("-t", "--timeout", help="Define bTCP timeout in milliseconds", type=int, default=100)
parser.add_argument("-i", "--input", help="File to send", default="tmp.file")
args = parser.parse_args()
binding = btcp.Binding(socket.AF_INET, ("", 9002), args.window, args.timeout)
binding = btcp.Binding(socket.AF_INET, None, args.window, args.timeout / 1000.)
connection = binding.connect_client(0, 0, ("", 9001))
file = open(args.input, "rb")
file = open(args.input, "r+b")
data = file.read()
file.close()
connection.send(struct.pack("!Q", len(data)))
connection.send(data)
input("press enter to stop\n")
binding.close()
#!/usr/local/bin/python3
#!/bin/python3
import argparse
import logging
import socket
import struct
import sys
import btcp
logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
# Handle arguments
parser = argparse.ArgumentParser()
parser.add_argument("-w", "--window", help="Define bTCP window size", type=int, default=100)
......@@ -12,7 +16,7 @@ parser.add_argument("-t", "--timeout", help="Define bTCP timeout in milliseconds
parser.add_argument("-o", "--output", help="Where to store file", default="tmp.file")
args = parser.parse_args()
binding = btcp.Binding(socket.AF_INET, ("", 9001), args.window, args.timeout)
binding = btcp.Binding(socket.AF_INET, ("", 9001), args.window, args.timeout / 1000.)
server = binding.bind_server(0)
server.start_listen(1)
......@@ -20,11 +24,9 @@ connection = server.accept()
file_size: int = struct.unpack("!Q", connection.receive(8))[0]
file = open(args.output, "wb")
while file_size > 0:
data = connection.receive_all()
file.write(data)
file_size -= len(data)
file = open(args.output, "w+b")
data = connection.receive(file_size)
file.write(data)
file.close()
binding.close()
import binascii
import bisect
import copy
import logging
import random
import select
import socket
......@@ -26,9 +27,19 @@ class _Flags:
else:
self.syn, self.ack, self.fin = flags
def to_int(self) -> int:
def __int__(self) -> int:
return self.syn << 2 | self.ack << 1 | self.fin
def __str__(self) -> str:
strs: List[str] = []
if self.syn:
strs.append("SYN")
if self.ack:
strs.append("ACK")
if self.fin:
strs.append("FIN")
return "-".join(strs)
class _Header:
def __init__(self, src_port: int, dest_port: int, syn_nr: int, ack_nr: int, flags: int,
......@@ -98,6 +109,7 @@ class _AddrPacket(_Packet):
self.remote_udp_addr = remote_udp_addr
# Info for if the connection was not initiated from our side
class _RemoteInit:
def __init__(self, syn_nr: int, window_size: int):
self.syn_nr = syn_nr
......@@ -107,13 +119,12 @@ class _RemoteInit:
class _Stream:
def __init__(self, binding: "Binding", local_port: int,
remote_port: int, remote_udp_addr: Any,
local_window_size: int, timeout: int,
local_window_size: int, timeout: float,
remote_init: Optional[_RemoteInit] = None):
self.binding = binding
self.local_port = local_port
self.remote_port = remote_port
self.local_window_size = local_window_size
self.remote_window_size = remote_init.window_size
self.remote_udp_addr = remote_udp_addr
self.timeout = timeout
self.expect_syn_ack = remote_init is None
......@@ -128,35 +139,48 @@ class _Stream:
self.receive_buffer_lock = threading.Lock()
self.receive_bytes_available = threading.Condition(self.receive_buffer_lock)
if remote_init is None:
self.send_syn_nr = random.randint(0, 0xffFF) # Next to be sent
self.deb = f"{self.binding.local_udp_addr}|{self.local_port}->{self.remote_port}: "
self.first_send_syn_nr = random.randint(0, 0xffFF)
if remote_init is None: # We initiate the connection
self.send_syn_nr = self.first_send_syn_nr # Next to be sent
self.send_ack_nr = self.send_syn_nr # First not ACKed by other
self.__send_to_be_acked(b"", self.send_syn_nr, _Flags((True, False, False)))
else:
self.remote_window_size = remote_init.window_size
self.recv_syn_nr = remote_init.syn_nr # Last received
self.recv_ack_nr = remote_init.syn_nr + 1 # First still to be received
self.send_syn_nr = random.randint(0, 0xffFF) # Next to be sent
self.send_syn_nr = self.first_send_syn_nr # Next to be sent
self.send_ack_nr = self.send_syn_nr # First not ACKed by other
self.first_recv_syn_nr = self.recv_syn_nr
self.__send_to_be_acked(b"", self.send_syn_nr, _Flags((True, True, False)))
self.send_syn_nr += 1
logging.debug(self.deb + "Setup stream")
def __send_to_be_acked(self, data: bytes, syn_nr: int, flags=_Flags((False, True, False))) -> None:
logging.debug(self.deb +
f"Sending #{syn_nr} {flags}{f' with ACK {self.recv_ack_nr}' if flags.ack else ''} "
f"with {len(data)} bytes")
self.send_syn_nr = max(syn_nr, self.send_syn_nr)
header = _Header(self.local_port, self.remote_port, syn_nr, self.recv_ack_nr if flags.ack else 0,
flags.to_int(), self.local_window_size, len(data))
packet = _TimestampPacket(header, data, time.clock())
int(flags), self.local_window_size, len(data))
packet = _TimestampPacket(header, data, time.perf_counter())
packet.set_checksum()
self.sent_ack_buffer.append(packet)
# noinspection PyProtectedMember
self.binding._send(packet, self.remote_udp_addr)
def __send_ack(self) -> None:
logging.debug(self.deb + f"Sending ACK {self.recv_ack_nr}")
ack_header = _Header(self.local_port, self.remote_port, self.send_syn_nr, self.recv_ack_nr,
_Flags((False, True, False)).to_int(), self.local_window_size, 0)
int(_Flags((False, True, False))), self.local_window_size, 0)
ack_packet = _Packet(ack_header, b"")
ack_packet.set_checksum()
......@@ -164,41 +188,49 @@ class _Stream:
self.binding._send(ack_packet, self.remote_udp_addr)
def pass_received_msgs(self, packets: List[_AddrPacket]) -> None:
logging.debug(self.deb + f"Received {len(packets)} messages")
send_ack = False
for packet in packets:
if packet.remote_udp_addr != self.remote_udp_addr:
print("Wrong remote UDP address")
continue
flags = packet.header.flags_obj()
logging.debug(self.deb + f"Received #{packet.header.syn_nr} {flags}"
f"{f' with ACK {packet.header.ack_nr}' if flags.ack else ''} "
f"with {len(packet.data)} bytes")
if self.expect_syn_ack:
if flags.syn and flags.ack and not flags.fin:
if packet.header.ack_nr != self.send_ack_nr:
print("Wrong ACK nr in SYN-ACK")
if packet.header.ack_nr != self.send_ack_nr + 1:
logging.warning(self.deb + "Wrong ACK nr in SYN-ACK")
continue
self.expect_syn_ack = False
self.recv_syn_nr = packet.header.syn_nr
self.recv_ack_nr = packet.header.syn_nr + 1
self.first_recv_syn_nr = self.recv_syn_nr
self.remote_window_size = packet.header.window
else:
print("SYN-ACK expected")
logging.warning(self.deb + "SYN-ACK expected")
continue
elif flags.syn:
print("Unexpected SYN")
# TODO simultaneous open?
continue
if packet.header.syn_nr == self.first_recv_syn_nr:
logging.debug(self.deb + "Spurious SYN")
continue
else:
logging.warning(self.deb + "Unexpected SYN")
# TODO simultaneous open
continue
if packet.header.syn_nr > self.recv_syn_nr + self.local_window_size:
print("Too high SYN")
logging.warning(self.deb + "Too high SYN")
continue
# TODO if flags.fin:
if flags.ack:
if packet.header.ack_nr > self.send_syn_nr + 1:
print("Too high ACK")
logging.warning(self.deb + "Too high ACK")
continue
if packet.header.ack_nr < self.first_send_syn_nr:
logging.warning(self.deb + "Too low ACK")
continue
if packet.header.ack_nr > self.send_ack_nr:
self.send_ack_nr = packet.header.ack_nr
......@@ -225,9 +257,12 @@ class _Stream:
elif packet.header.syn_nr > self.recv_ack_nr:
# There is a gap, store this packet
insert_index = bisect.bisect_left(self.receive_ack_buffer, packet)
# Check for duplicate packet
if (insert_index + 1 > len(self.receive_ack_buffer)
or self.receive_ack_buffer[insert_index + 1].header.syn_nr != packet.header.syn_nr):
if (insert_index + 1 < len(self.receive_ack_buffer)
and self.receive_ack_buffer[insert_index + 1].header.syn_nr == packet.header.syn_nr):
assert self.receive_ack_buffer[insert_index + 1].data == packet.data
logging.debug(self.deb + "Spurious retransmission")
else:
self.receive_ack_buffer.insert(insert_index, packet)
send_ack = True
......@@ -237,27 +272,34 @@ class _Stream:
self.send_buffer_lock.acquire()
data_to_send = len(self.send_buffer) > 0
self.send_buffer_lock.release()
if data_to_send:
if not data_to_send:
self.__send_ack()
def poll_sender(self) -> None:
lost = [p for p in self.sent_ack_buffer if time.perf_counter() - p.timestamp >= self.timeout]
self.sent_ack_buffer = [p for p in self.sent_ack_buffer if time.perf_counter() - p.timestamp < self.timeout]
if len(lost) > 0:
logging.debug(self.deb + f"{len(lost)} lost messages")
for p in lost:
p.timestamp = time.perf_counter()
self.__send_to_be_acked(p.data, p.header.syn_nr, p.header.flags_obj())
if self.expect_syn_ack:
return
lost = [p for p in self.sent_ack_buffer if time.clock() - p.timestamp >= self.timeout]
for p in lost:
p.timestamp = time.clock()
self.__send_to_be_acked(p.data, p.header.syn_nr, p.header.flags_obj())
if len(self.send_buffer) > 0:
logging.debug(self.deb + f"{len(self.send_buffer)} unsent messages")
if len(self.sent_ack_buffer) < self.remote_window_size:
self.send_buffer_lock.acquire()
while len(self.sent_ack_buffer) < self.remote_window_size:
while len(self.send_buffer) > 0 and len(self.sent_ack_buffer) < self.remote_window_size:
data = self.send_buffer.pop(0)
self.__send_to_be_acked(data + bytes(payload_size - len(data)), self.send_syn_nr)
self.__send_to_be_acked(data, self.send_syn_nr)
self.send_syn_nr += 1
self.send_buffer_lock.release()
def enqueue_data(self, data: bytes) -> None:
logging.debug(self.deb + f"Enqueuing {len(data)} bytes")
self.send_buffer_lock.acquire()
offset = 0
while offset < len(data):
......@@ -268,11 +310,12 @@ class _Stream:
def __at_least_received(self, count: int) -> bool:
i = 0
while i < len(self.receive_buffer) and count > 0:
count += len(self.receive_buffer[i])
count -= len(self.receive_buffer[i])
i += 1
return count <= 0
def dequeue_data(self, count: int) -> bytes:
logging.debug(self.deb + f"Dequeuing {count} bytes...")
self.receive_buffer_lock.acquire()
self.receive_bytes_available.wait_for(lambda: self.__at_least_received(count))
......@@ -281,10 +324,12 @@ class _Stream:
data += self.receive_buffer.pop(0)
if len(data) < count:
data += self.receive_buffer[0][:len(data) - count]
self.receive_buffer[0] = self.receive_buffer[0][len(data) - count:]
bytes_short = count - len(data)
data += self.receive_buffer[0][:bytes_short]
self.receive_buffer[0] = self.receive_buffer[0][bytes_short:]
self.receive_buffer_lock.release()
logging.debug(self.deb + "Dequeued")
return data
def dequeue_is_available(self, count: int) -> bool:
......@@ -294,14 +339,17 @@ class _Stream:
return available
def dequeue_all_data(self) -> bytes:
logging.debug(self.deb + "Dequeuing all data...")
self.receive_buffer_lock.acquire()
data = b""
while len(self.receive_buffer) > 0:
data += self.receive_buffer.pop(0)
self.receive_buffer_lock.release()
logging.debug(self.deb + f"Dequeued {len(data)} bytes")
return data
def close(self) -> None:
logging.debug(self.deb + "Close")
# TODO FIN
pass
......@@ -310,6 +358,11 @@ class _Connection:
def __init__(self, parent: Union["_Server", "Binding"], stream: _Stream):
self.parent = parent
self.stream = stream
self.deb = f"{self.__parent_str()}|{self.stream.local_port}->{self.stream.remote_port}: "
logging.debug(self.deb + "Setup connection")
def __parent_str(self) -> str:
return self.parent.local_udp_addr if type(self.parent) is Binding else self.parent.binding.local_udp_addr
def _pass_received_msgs(self, packets: List[_AddrPacket]) -> None:
self.stream.pass_received_msgs(packets)
......@@ -330,6 +383,7 @@ class _Connection:
return self.stream.dequeue_all_data()
def close(self) -> None:
logging.debug(self.deb + "Close")
if type(self.parent) is _Server:
# noinspection PyProtectedMember
self.parent._remove_stream(self.stream.remote_port)
......@@ -337,20 +391,27 @@ class _Connection:
# noinspection PyProtectedMember
self.parent._remove_socket(self.stream.remote_port)
self.stream.close()
logging.debug(self.deb + "Closed")
class _Server:
def __init__(self, binding: "Binding", local_port: int):
self.binding = binding
self.local_port = local_port
self.streams: Dict[int, _Stream] = {}
self.streams: Dict[int, _Stream] = {} # Remote port -> _Stream
self.streams_lock = threading.Lock()
self.backlog = 0
self.backlog_connections: List[_Connection] = []
self.backlog_lock = threading.Lock()
self.backlog_available = threading.Condition(self.backlog_lock)
self.deb = f"{self.binding.local_udp_addr}|{self.local_port}: "
logging.debug(self.deb + "Setup server")
def _pass_received_msgs(self, packets: List[_AddrPacket]) -> None:
logging.debug(self.deb + f"Received {len(packets)} messages")
packet_batches: Dict[int, List[_AddrPacket]] = {}
orphan_packets: List[_AddrPacket] = []
......@@ -358,8 +419,7 @@ class _Server:
for packet in packets:
if packet.header.src_port in self.streams:
batch = packet_batches[packet.header.src_port] = []
batch.append(packet)
packet_batches.setdefault(packet.header.src_port, []).append(packet)
else:
orphan_packets.append(packet)
......@@ -379,12 +439,13 @@ class _Server:
_RemoteInit(packet.header.syn_nr, packet.header.window))
self.backlog_connections.append(_Connection(self, stream))
self.backlog -= 1
self.backlog_available.notify()
self.backlog_lock.release()
else:
self.backlog_lock.release()
print("Backlog full")
logging.warning(self.deb + "Backlog full")
else:
print("Unknown connection")
logging.warning(self.deb + "Unknown connection")
self.streams_lock.release()
......@@ -395,47 +456,58 @@ class _Server:
self.streams_lock.release()
def _remove_stream(self, remote_port: int) -> None:
logging.debug(self.deb + "Remove stream")
self.streams_lock.acquire()
del self.streams[remote_port]
self.streams_lock.release()
def start_listen(self, backlog: int) -> None:
logging.debug(self.deb + "Start listen")
self.backlog_lock.acquire()
self.backlog = backlog
self.backlog_lock.release()
def accept(self) -> _Connection:
logging.debug(self.deb + "Accepting...")
self.backlog_lock.acquire()
self.backlog_available.wait_for(lambda: len(self.backlog_connections) > 0)
connection = self.backlog_connections.pop(0)
self.backlog_lock.release()
logging.debug(self.deb + "Accepted")
return connection
def close(self) -> None:
logging.debug(self.deb + "Closing...")
# noinspection PyProtectedMember
self.binding._remove_socket(self.local_port)
for stream in self.streams.values():
stream.close()
while len(self.streams) > 0:
self.streams.pop(0).close()
logging.debug(self.deb + "Closed")
class Binding:
def __init__(self, protocol: int, local_udp_addr: Optional[Any],
window_size: int, timeout: int):
window_size: int, timeout_sec: float):
self.local_udp_addr = local_udp_addr
self.window_size = window_size
self.timeout = timeout
self.timeout = timeout_sec
self.sock = socket.socket(protocol, socket.SOCK_DGRAM)
if local_udp_addr is not None:
self.sock.bind(local_udp_addr)
self.sockets: Dict[int, Union[_Server, _Connection]] = {}
self.sockets: Dict[int, Union[_Server, _Connection]] = {} # Local port -> _Server / _Connection
self.stop = False
self.sockets_stop_lock = threading.RLock()
self.read_thread = threading.Thread(None, self.__background)
self.read_thread.start()
self.deb = f"{self.local_udp_addr}: "
logging.debug(self.deb + "Set up binding")
def __background(self) -> None:
poller = select.poll()
poller = select.poll() # Does not work on Windows because Python is stupid (WSAPoll is a thing)
poller.register(self.sock, select.POLLIN)
while True:
......@@ -447,18 +519,20 @@ class Binding:
if self.stop:
break
while (self.sock, select.POLLIN) in poll_result:
while len(poll_result) > 0:
data, addr = self.sock.recvfrom(header_size + payload_size, socket.MSG_WAITALL)
packet = _AddrPacket(_unpack_header(data[:header_size]), data[header_size:], addr)
logging.debug(self.deb + f"Packet from {addr}")
packet = _AddrPacket(_unpack_header(data[:header_size]),
data[header_size:], addr)
packet.data = packet.data[:-(payload_size - packet.header.data_length)]
if not packet.verify_checksum():
print("Invalid checksum")
logging.warning(self.deb + "Invalid checksum")
else:
if packet.header.src_port in self.sockets:
batch = packet_batches[packet.header.src_port] = []
batch.append(packet)
if packet.header.dest_port in self.sockets:
packet_batches.setdefault(packet.header.dest_port, []).append(packet)
else:
print("Unknown server")
logging.warning(self.deb + "Unknown server")
poll_result = poller.poll(0)
......@@ -472,23 +546,29 @@ class Binding:
self.sockets_stop_lock.release()
self.sockets_stop_lock.release()
def _send(self, packet: _Packet, remote_udp_addr: Any) -> None:
data = _pack_header(packet.header) + packet.data
logging.debug(self.deb + f"Send to {remote_udp_addr}")
data = _pack_header(packet.header) + packet.data + bytes(payload_size - len(packet.data))
while len(data) > 0:
data = data[self.sock.sendto(data, remote_udp_addr):]
def _remove_socket(self, local_port: int) -> None:
logging.debug(self.deb + f"Remove socket {local_port}")
self.sockets_stop_lock.acquire()
del self.sockets[local_port]
self.sockets_stop_lock.release()
def bind_server(self, local_btcp_port: int) -> _Server:
logging.debug(self.deb + f"Bind server to btcp {local_btcp_port}")
self.sockets_stop_lock.acquire()
server = self.sockets[local_btcp_port] = _Server(self, local_btcp_port)
self.sockets_stop_lock.release()
return server
def connect_client(self, local_btcp_port: int, remote_btcp_port: int, remote_udp_addr: Any) -> _Connection:
logging.debug(self.deb + f"Connect to btcp {local_btcp_port} -> {remote_btcp_port} {remote_udp_addr}")
self.sockets_stop_lock.acquire()
connection = _Connection(self, _Stream(self, local_btcp_port, remote_btcp_port, remote_udp_addr,
self.window_size, self.timeout))
......@@ -497,14 +577,17 @@ class Binding:
return connection
def close(self) -> None:
logging.debug(self.deb + "Closing...")
self.sockets_stop_lock.acquire()
self.stop = True
self.sockets_stop_lock.release()
self.read_thread.join()
logging.debug(self.deb + "Thread exited")
self.sockets_stop_lock.acquire()
for server in self.sockets.values():
server.close()
while len(self.sockets) > 0:
self.sockets[0].close()
self.sockets_stop_lock.release()
self.sock.close()
logging.debug(self.deb + "Closed")
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment