Commit b8793705 authored by StevenWdV's avatar StevenWdV
Browse files

Moved header packing functions into _Header, decreased first SYN nr range,...

Moved header packing functions into _Header, decreased first SYN nr range, added too low SYN check, added polling time param, added data_length check
parent f7bb06c0
...@@ -15,7 +15,7 @@ parser.add_argument("-t", "--timeout", help="Define bTCP timeout in milliseconds ...@@ -15,7 +15,7 @@ parser.add_argument("-t", "--timeout", help="Define bTCP timeout in milliseconds
parser.add_argument("-i", "--input", help="File to send", default="tmp.file") parser.add_argument("-i", "--input", help="File to send", default="tmp.file")
args = parser.parse_args() args = parser.parse_args()
binding = btcp.Binding(socket.AF_INET, None, args.window, args.timeout / 1000.) binding = btcp.Binding(socket.AF_INET, None, args.window, args.timeout)
connection = binding.connect_client(0, 0, ("", 9001)) connection = binding.connect_client(0, 0, ("", 9001))
file = open(args.input, "r+b") file = open(args.input, "r+b")
......
...@@ -16,7 +16,7 @@ parser.add_argument("-t", "--timeout", help="Define bTCP timeout in milliseconds ...@@ -16,7 +16,7 @@ parser.add_argument("-t", "--timeout", help="Define bTCP timeout in milliseconds
parser.add_argument("-o", "--output", help="Where to store file", default="tmp.file") parser.add_argument("-o", "--output", help="Where to store file", default="tmp.file")
args = parser.parse_args() args = parser.parse_args()
binding = btcp.Binding(socket.AF_INET, ("", 9001), args.window, args.timeout / 1000.) binding = btcp.Binding(socket.AF_INET, ("", 9001), args.window, args.timeout)
server = binding.bind_server(0) server = binding.bind_server(0)
server.start_listen(1) server.start_listen(1)
......
...@@ -42,6 +42,10 @@ class _Flags: ...@@ -42,6 +42,10 @@ class _Flags:
class _Header: class _Header:
@staticmethod
def unpack(header: bytes) -> "_Header":
return _Header(*struct.unpack(header_format, header))
def __init__(self, src_port: int, dest_port: int, syn_nr: int, ack_nr: int, flags: int, def __init__(self, src_port: int, dest_port: int, syn_nr: int, ack_nr: int, flags: int,
window: int, data_length: int, checksum: int = 0): window: int, data_length: int, checksum: int = 0):
self.dest_port = dest_port self.dest_port = dest_port
...@@ -56,22 +60,17 @@ class _Header: ...@@ -56,22 +60,17 @@ class _Header:
def flags_obj(self) -> _Flags: def flags_obj(self) -> _Flags:
return _Flags(self.flags) return _Flags(self.flags)
def __bytes__(self) -> bytes:
def _pack_header(header: _Header) -> bytes: return struct.pack(
return struct.pack( header_format,
header_format, self.src_port,
header.src_port, self.dest_port,
header.dest_port, self.syn_nr,
header.syn_nr, self.ack_nr,
header.ack_nr, self.flags,
header.flags, self.window,
header.window, self.data_length,
header.data_length, self.checksum)
header.checksum)
def _unpack_header(header: bytes) -> _Header:
return _Header(*struct.unpack(header_format, header))
class _Packet: class _Packet:
...@@ -85,7 +84,7 @@ class _Packet: ...@@ -85,7 +84,7 @@ class _Packet:
else: else:
header_no_checksum = copy.deepcopy(self.header) header_no_checksum = copy.deepcopy(self.header)
header_no_checksum.checksum = 0 header_no_checksum.checksum = 0
return binascii.crc32(_pack_header(header_no_checksum) + self.data) return binascii.crc32(bytes(header_no_checksum) + self.data)
def verify_checksum(self) -> bool: def verify_checksum(self) -> bool:
return self.compute_checksum() == self.header.checksum return self.compute_checksum() == self.header.checksum
...@@ -116,6 +115,8 @@ class _RemoteInit: ...@@ -116,6 +115,8 @@ class _RemoteInit:
self.window_size = window_size self.window_size = window_size
# TODO more timeouts
# TODO fast retransmit
class _Stream: class _Stream:
def __init__(self, binding: "Binding", local_port: int, def __init__(self, binding: "Binding", local_port: int,
remote_port: int, remote_udp_addr: Any, remote_port: int, remote_udp_addr: Any,
...@@ -141,7 +142,7 @@ class _Stream: ...@@ -141,7 +142,7 @@ class _Stream:
self.deb = f"{self.binding.local_udp_addr}|{self.local_port}->{self.remote_port}: " self.deb = f"{self.binding.local_udp_addr}|{self.local_port}->{self.remote_port}: "
self.first_send_syn_nr = random.randint(0, 0xffFF) self.first_send_syn_nr = random.randint(0, 0xff)
if remote_init is None: # We initiate the connection if remote_init is None: # We initiate the connection
self.send_syn_nr = self.first_send_syn_nr # Next to be sent self.send_syn_nr = self.first_send_syn_nr # Next to be sent
...@@ -221,6 +222,10 @@ class _Stream: ...@@ -221,6 +222,10 @@ class _Stream:
# TODO simultaneous open if self.first_recv_syn_nr is None # TODO simultaneous open if self.first_recv_syn_nr is None
continue continue
if packet.header.syn_nr < self.first_recv_syn_nr:
logging.warning(self.deb + "Too low SYN")
continue
if packet.header.syn_nr > self.recv_syn_nr + self.local_window_size: if packet.header.syn_nr > self.recv_syn_nr + self.local_window_size:
logging.warning(self.deb + "Too high SYN") logging.warning(self.deb + "Too high SYN")
continue continue
...@@ -437,7 +442,7 @@ class _Server: ...@@ -437,7 +442,7 @@ class _Server:
self.streams[packet.header.src_port] = stream = _Stream( self.streams[packet.header.src_port] = stream = _Stream(
self.binding, self.local_port, self.binding, self.local_port,
packet.header.src_port, packet.remote_udp_addr, packet.header.src_port, packet.remote_udp_addr,
self.binding.window_size, self.binding.timeout, self.binding.window_size, self.binding.timeout_sec,
_RemoteInit(packet.header.syn_nr, packet.header.window)) _RemoteInit(packet.header.syn_nr, packet.header.window))
self.backlog_connections.append(_Connection(self, stream)) self.backlog_connections.append(_Connection(self, stream))
self.backlog -= 1 self.backlog -= 1
...@@ -489,10 +494,11 @@ class _Server: ...@@ -489,10 +494,11 @@ class _Server:
class Binding: class Binding:
def __init__(self, protocol: int, local_udp_addr: Optional[Any], def __init__(self, protocol: int, local_udp_addr: Optional[Any],
window_size: int, timeout_sec: float): window_size: int, timeout_ms: int, poll_time_ms: Optional[int] = None):
self.local_udp_addr = local_udp_addr self.local_udp_addr = local_udp_addr
self.window_size = window_size self.window_size = window_size
self.timeout = timeout_sec self.timeout_sec = timeout_ms / 1000
self.poll_time_ms = poll_time_ms or timeout_ms
self.sock = socket.socket(protocol, socket.SOCK_DGRAM) self.sock = socket.socket(protocol, socket.SOCK_DGRAM)
if local_udp_addr is not None: if local_udp_addr is not None:
...@@ -513,7 +519,7 @@ class Binding: ...@@ -513,7 +519,7 @@ class Binding:
poller.register(self.sock, select.POLLIN) poller.register(self.sock, select.POLLIN)
while True: while True:
poll_result = poller.poll(100) poll_result = poller.poll(self.poll_time_ms)
packet_batches: Dict[int, List[_AddrPacket]] = {} packet_batches: Dict[int, List[_AddrPacket]] = {}
...@@ -524,8 +530,10 @@ class Binding: ...@@ -524,8 +530,10 @@ class Binding:
while len(poll_result) > 0: while len(poll_result) > 0:
data, addr = self.sock.recvfrom(header_size + payload_size, socket.MSG_WAITALL) data, addr = self.sock.recvfrom(header_size + payload_size, socket.MSG_WAITALL)
logging.debug(self.deb + f"Packet from {addr}") logging.debug(self.deb + f"Packet from {addr}")
packet = _AddrPacket(_unpack_header(data[:header_size]), packet = _AddrPacket(_Header.unpack(data[:header_size]),
data[header_size:], addr) data[header_size:], addr)
if packet.header.data_length > payload_size:
logging.warning(self.deb + "data_length too large")
packet.data = packet.data[:-(payload_size - packet.header.data_length)] packet.data = packet.data[:-(payload_size - packet.header.data_length)]
if not packet.verify_checksum(): if not packet.verify_checksum():
...@@ -552,7 +560,7 @@ class Binding: ...@@ -552,7 +560,7 @@ class Binding:
def _send(self, packet: _Packet, remote_udp_addr: Any) -> None: def _send(self, packet: _Packet, remote_udp_addr: Any) -> None:
logging.debug(self.deb + f"Send to {remote_udp_addr}") logging.debug(self.deb + f"Send to {remote_udp_addr}")
data = _pack_header(packet.header) + packet.data + bytes(payload_size - len(packet.data)) data = bytes(packet.header) + packet.data + bytes(payload_size - len(packet.data))
while len(data) > 0: while len(data) > 0:
data = data[self.sock.sendto(data, remote_udp_addr):] data = data[self.sock.sendto(data, remote_udp_addr):]
...@@ -573,7 +581,7 @@ class Binding: ...@@ -573,7 +581,7 @@ class Binding:
logging.debug(self.deb + f"Connect to btcp {local_btcp_port} -> {remote_btcp_port} {remote_udp_addr}") logging.debug(self.deb + f"Connect to btcp {local_btcp_port} -> {remote_btcp_port} {remote_udp_addr}")
self.sockets_stop_lock.acquire() self.sockets_stop_lock.acquire()
connection = _Connection(self, _Stream(self, local_btcp_port, remote_btcp_port, remote_udp_addr, connection = _Connection(self, _Stream(self, local_btcp_port, remote_btcp_port, remote_udp_addr,
self.window_size, self.timeout)) self.window_size, self.timeout_sec))
self.sockets[local_btcp_port] = connection self.sockets[local_btcp_port] = connection
self.sockets_stop_lock.release() self.sockets_stop_lock.release()
return connection return connection
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment