btcp.py 28.6 KB
Newer Older
StevenWdV's avatar
StevenWdV committed
1
import binascii
2
import bisect
StevenWdV's avatar
StevenWdV committed
3
import copy
StevenWdV's avatar
StevenWdV committed
4
import logging
StevenWdV's avatar
StevenWdV committed
5
import random
6 7 8 9 10 11
import select
import socket
import struct
import threading
import time
from typing import *
StevenWdV's avatar
StevenWdV committed
12

13
# Wire layout of the bTCP header (network byte order): src port, dest port,
# SYN nr, ACK nr, flags, window, data length, CRC32 checksum.
header_format = "!HHHHBBHI"
# Derived from the format so the two can never disagree (16 bytes).
header_size = struct.calcsize(header_format)
payload_size = 1000  # Maximum number of payload bytes carried by one packet
max_seq = 0xFFFF  # Largest value that fits the 16-bit SYN/ACK number fields

enable_more_asserts = False  # Extra consistency asserts in the receive path
debug_window = False  # Print the receive window after every received batch
20

21 22 23 24 25 26 27 28 29 30 31 32 33

class _Flags:
    def __init__(self, flags: Union[int, Tuple[bool, bool, bool]] = 0):
        """
        :param flags: Combined or (SYN, ACK, FIN)
        """
        if type(flags) is int:
            self.fin = bool(flags & 1)
            self.ack = bool(flags & 1 << 1)
            self.syn = bool(flags & 1 << 2)
        else:
            self.syn, self.ack, self.fin = flags

StevenWdV's avatar
StevenWdV committed
34
    def __int__(self) -> int:
35
        return self.syn << 2 | self.ack << 1 | self.fin
StevenWdV's avatar
StevenWdV committed
36

StevenWdV's avatar
StevenWdV committed
37 38 39 40 41 42 43 44 45 46
    def __str__(self) -> str:
        strs: List[str] = []
        if self.syn:
            strs.append("SYN")
        if self.ack:
            strs.append("ACK")
        if self.fin:
            strs.append("FIN")
        return "-".join(strs)

StevenWdV's avatar
StevenWdV committed
47

48
class _Header:
    """Fixed-layout bTCP packet header; the wire layout is ``header_format``."""

    @staticmethod
    def unpack(header: bytes) -> "_Header":
        """Decode a packed header (exactly ``header_size`` bytes) into a _Header."""
        (src_port, dest_port, syn_nr, ack_nr,
         flags, window, data_length, checksum) = struct.unpack(header_format, header)
        return _Header(src_port, dest_port, syn_nr, ack_nr, flags, window, data_length, checksum)

    def __init__(self, src_port: int, dest_port: int, syn_nr: int, ack_nr: int, flags: int,
                 window: int, data_length: int, checksum: int = 0):
        # Stored in the same order as the fields appear on the wire.
        self.src_port = src_port
        self.dest_port = dest_port
        self.syn_nr = syn_nr
        self.ack_nr = ack_nr
        self.flags = flags  # Packed flag byte; use flags_obj() for the decoded view
        self.window = window
        self.data_length = data_length
        self.checksum = checksum

    def flags_obj(self) -> _Flags:
        """Decode the packed flag byte."""
        return _Flags(self.flags)

    def __bytes__(self) -> bytes:
        """Pack the header in network byte order according to ``header_format``."""
        values = (self.src_port, self.dest_port, self.syn_nr, self.ack_nr,
                  self.flags, self.window, self.data_length, self.checksum)
        return struct.pack(header_format, *values)
StevenWdV's avatar
StevenWdV committed
78

StevenWdV's avatar
StevenWdV committed
79

80 81
class _Packet:
    def __init__(self, header: _Header, data: bytes):
StevenWdV's avatar
StevenWdV committed
82 83
        self.header = header
        self.data = data
StevenWdV's avatar
StevenWdV committed
84

StevenWdV's avatar
StevenWdV committed
85
    def compute_checksum(self) -> int:
86 87 88 89 90
        if self.header.checksum == 0:
            header_no_checksum = self.header
        else:
            header_no_checksum = copy.deepcopy(self.header)
            header_no_checksum.checksum = 0
91
        return binascii.crc32(bytes(header_no_checksum) + self.data)
StevenWdV's avatar
StevenWdV committed
92

StevenWdV's avatar
StevenWdV committed
93 94 95
    def verify_checksum(self) -> bool:
        return self.compute_checksum() == self.header.checksum

96
    def set_checksum(self) -> None:
StevenWdV's avatar
StevenWdV committed
97
        self.header.checksum = self.compute_checksum()
StevenWdV's avatar
StevenWdV committed
98

99
    def __lt__(self, other: "_Packet") -> bool:
100
        return self.header.syn_nr < other.header.syn_nr
StevenWdV's avatar
StevenWdV committed
101

StevenWdV's avatar
StevenWdV committed
102

103 104 105 106
class _TimestampPacket(_Packet):
    """A sent packet together with the moment it was sent, for retransmission timing."""

    def __init__(self, header: _Header, data: bytes, timestamp: float):
        super().__init__(header=header, data=data)
        self.timestamp = timestamp  # time.perf_counter() value at send time
StevenWdV's avatar
StevenWdV committed
107 108


109 110 111 112
class _AddrPacket(_Packet):
    """A received packet tagged with the UDP address associated with it.

    _Server uses that address as the reply address for newly accepted streams.
    """

    def __init__(self, header: _Header, data: bytes, remote_udp_addr: Tuple[str, int]):
        super().__init__(header=header, data=data)
        self.remote_udp_addr = remote_udp_addr
StevenWdV's avatar
StevenWdV committed
113 114


StevenWdV's avatar
StevenWdV committed
115
# Info for if the connection was not initiated from our side
116 117 118 119
class _RemoteInit:
    def __init__(self, syn_nr: int, window_size: int):
        self.syn_nr = syn_nr
        self.window_size = window_size
StevenWdV's avatar
StevenWdV committed
120 121


122 123 124
class _Stream:
    """One end of a bTCP connection: reliable, ordered delivery with sliding windows.

    Send side: application data is split into ``payload_size`` chunks in
    ``send_buffer``. ``poll_sender`` transmits them while fewer than
    ``remote_window_size`` packets are unACKed, and retransmits packets that
    stay unACKed for ``timeout`` seconds.

    Receive side: ``pass_received_msgs`` moves in-order payloads to
    ``receive_buffer`` and parks packets that arrive ahead of a gap in
    ``receive_ack_buffer`` (kept sorted by raw SYN number).
    """

    def __init__(self, binding: "Binding", local_port: int,
                 remote_port: int, remote_udp_addr: Any,
                 local_window_size: int, timeout: float,
                 remote_init: Optional[_RemoteInit] = None):
        """Create the stream and immediately send the opening SYN or SYN-ACK.

        :param binding: Binding used to transmit packets.
        :param local_port: bTCP port of this end.
        :param remote_port: bTCP port of the peer.
        :param remote_udp_addr: UDP address packets are sent to.
        :param local_window_size: Receive window advertised to the peer.
        :param timeout: Seconds before an unACKed packet is retransmitted.
        :param remote_init: Info from the peer's SYN when the peer opened the
            connection; ``None`` when we initiate (we then send a SYN and wait
            for the SYN-ACK).
        """
        self.binding = binding
        self.local_port = local_port
        self.remote_port = remote_port
        self.local_window_size = local_window_size
        self.remote_udp_addr = remote_udp_addr
        self.timeout = timeout
        self.expect_syn_ack = remote_init is None

        self.send_buffer: List[bytes] = []  # Data to be sent
        self.sent_ack_buffer: List[_TimestampPacket] = []  # Sent, still to be ACKed by the peer

        self.receive_ack_buffer: List[_Packet] = []  # Received out of order, not yet ACKed by us
        self.receive_buffer: List[bytes] = []  # Received and ACKed, not yet retrieved

        self.send_buffer_lock = threading.Lock()
        self.send_buffer_empty = threading.Condition(self.send_buffer_lock)
        self.receive_buffer_lock = threading.Lock()
        self.receive_bytes_available = threading.Condition(self.receive_buffer_lock)

        self.all_acked_lock = threading.Lock()
        self.all_acked = threading.Condition(self.all_acked_lock)

        self.receiver_closed = False  # If the remote host has sent a FIN
        self.sender_closed = False  # If we have sent a FIN

        self.receiver_closed_lock = threading.Lock()
        self.receiver_closed_event = threading.Condition(self.receiver_closed_lock)

        self.deb = f"{self.binding.local_udp_addr}|{self.local_port}->{self.remote_port}: "

        # __next_seq wraps modulo max_seq, so the reachable sequence numbers are
        # 0..max_seq-1; keep the initial number inside that range as well.
        self.first_send_syn_nr = random.randint(0, max_seq - 1)
        self.send_syn_nr = self.first_send_syn_nr  # Next to be sent
        self.send_ack_nr = self.send_syn_nr  # First not ACKed by other

        if remote_init is None:  # We initiate the connection
            self.first_recv_syn_nr = None
            self.__send_to_be_acked(b"", self.send_syn_nr, _Flags((True, False, False)))
        else:  # The peer initiated: answer its SYN with a SYN-ACK
            self.remote_window_size = remote_init.window_size
            self.recv_syn_nr = remote_init.syn_nr  # Last received
            self.recv_ack_nr = self.__next_seq(remote_init.syn_nr)  # First still to be received
            self.first_recv_syn_nr = self.recv_syn_nr
            self.__send_to_be_acked(b"", self.send_syn_nr, _Flags((True, True, False)))

        # The SYN / SYN-ACK occupies one sequence number.
        self.send_syn_nr = self.__next_seq(self.send_syn_nr)
        logging.debug(self.deb + "Setup stream")

    @staticmethod
    def __next_seq(nr: int) -> int:
        """Sequence number following ``nr`` (wraps modulo max_seq)."""
        return (nr + 1) % max_seq

    @staticmethod
    def __seq_between(nr: int, lower_bound: int, upper_bound: int) -> bool:
        """Whether ``nr`` lies in the inclusive range [lower_bound, upper_bound], all taken modulo max_seq.

        The range may wrap around (lower_bound > upper_bound after reduction).
        """
        nr %= max_seq
        lower_bound %= max_seq
        upper_bound %= max_seq
        if lower_bound <= upper_bound:
            return lower_bound <= nr <= upper_bound
        else:
            return nr >= lower_bound or nr <= upper_bound

    def __send_to_be_acked(self, data: bytes, syn_nr: int, flags: Optional[_Flags] = None) -> None:
        """Send a packet occupying ``syn_nr`` and keep it in sent_ack_buffer until it is ACKed.

        :param flags: Defaults to a plain ACK. A fresh object is made per call
            instead of sharing one default instance.
        """
        if flags is None:
            flags = _Flags((False, True, False))
        logging.debug(self.deb +
                      f"Sending #{syn_nr} {flags}{f' with ACK {self.recv_ack_nr}' if flags.ack else ''} "
                      f"with {len(data)} bytes")
        header = _Header(self.local_port, self.remote_port, syn_nr, self.recv_ack_nr if flags.ack else 0,
                         int(flags), self.local_window_size, len(data))
        packet = _TimestampPacket(header, data, time.perf_counter())
        packet.set_checksum()
        self.sent_ack_buffer.append(packet)
        # noinspection PyProtectedMember
        self.binding._send(packet, self.remote_udp_addr)

    def __send_ack(self) -> None:
        """Send a pure ACK (no payload, not retransmitted) for everything received in order."""
        logging.debug(self.deb + f"Sending ACK {self.recv_ack_nr}")
        ack_header = _Header(self.local_port, self.remote_port, self.send_syn_nr, self.recv_ack_nr,
                             int(_Flags((False, True, False))), self.local_window_size, 0)

        ack_packet = _Packet(ack_header, b"")
        ack_packet.set_checksum()
        # noinspection PyProtectedMember
        self.binding._send(ack_packet, self.remote_udp_addr)

    def pass_received_msgs(self, packets: List[_AddrPacket]) -> None:
        """Process a batch of packets addressed to this stream.

        Handles the handshake, ACKs for our outstanding packets, FIN, and
        payload delivery. At most one pure ACK is sent for the whole batch,
        and none when queued data can piggyback the ACK instead.
        """
        logging.debug(self.deb + f"Received {len(packets)} messages")
        send_ack = False

        for packet in packets:
            flags = packet.header.flags_obj()
            logging.debug(self.deb + f"Received #{packet.header.syn_nr} {flags}"
                          f"{f' with ACK {packet.header.ack_nr}' if flags.ack else ''} "
                          f"with {len(packet.data)} bytes")

            if self.expect_syn_ack:
                if not (flags.syn and flags.ack and not flags.fin):
                    logging.warning(self.deb + "SYN-ACK expected")
                    continue
                if packet.header.ack_nr != self.__next_seq(self.send_ack_nr):
                    logging.warning(self.deb + "Wrong ACK nr in SYN-ACK")
                    continue
                self.expect_syn_ack = False
                self.recv_syn_nr = packet.header.syn_nr
                self.recv_ack_nr = self.__next_seq(packet.header.syn_nr)
                self.first_recv_syn_nr = self.recv_syn_nr
                self.remote_window_size = packet.header.window
                send_ack = True
            elif flags.syn:
                if self.first_recv_syn_nr is not None and packet.header.syn_nr == self.first_recv_syn_nr:
                    logging.warning(self.deb + "Spurious SYN")
                    if flags.ack:
                        # Duplicate SYN-ACK: our ACK for it was probably lost. Re-ACK it
                        # so the peer stops retransmitting the SYN-ACK.
                        send_ack = True
                else:
                    logging.warning(self.deb + "Unexpected SYN")
                    # (Would be simultaneous open if self.first_recv_syn_nr is None)
                continue

            if not self.__seq_between(packet.header.syn_nr, self.recv_ack_nr - self.local_window_size,
                                      self.recv_ack_nr + self.local_window_size - 1):
                logging.info(self.deb + "SYN nr outside of window (spurious retransmission?)")
                continue

            if self.__seq_between(packet.header.syn_nr, self.recv_ack_nr + 1,
                                  self.recv_ack_nr + self.local_window_size - 1):
                self.recv_syn_nr = packet.header.syn_nr

            if flags.ack:
                self.__process_ack(packet.header.ack_nr)

            if flags.fin:
                with self.receiver_closed_lock:
                    self.receiver_closed = True
                    self.receiver_closed_event.notify()

            if packet.header.data_length > 0 or flags.fin:
                self.__store_received(packet)
                send_ack = True

        if send_ack:
            # Don't send ACK now if we can piggyback it on the data
            with self.send_buffer_lock:
                data_to_send = len(self.send_buffer) > 0
            if not data_to_send:
                self.__send_ack()

        if debug_window:
            self.__print_window()

    def __process_ack(self, ack_nr: int) -> None:
        """Apply a cumulative ACK.

        Drops newly ACKed packets from sent_ack_buffer and wakes close_sender
        once nothing is outstanding. ACK numbers outside
        (send_ack_nr, send_syn_nr] are ignored.
        """
        if not self.__seq_between(ack_nr, self.send_ack_nr + 1, self.send_syn_nr):
            return
        # Delete ACKed messages
        self.sent_ack_buffer = [
            p for p in self.sent_ack_buffer
            if not self.__seq_between(p.header.syn_nr, self.send_ack_nr, ack_nr - 1)]
        self.send_ack_nr = ack_nr
        if self.send_ack_nr == self.send_syn_nr:
            with self.all_acked_lock:
                self.all_acked.notify()

    def __store_received(self, packet: _Packet) -> None:
        """Handle a data/FIN packet.

        If it is the next expected packet, deliver it together with every
        buffered packet that is now contiguous. If it lies ahead of a gap
        within the window, buffer it. Otherwise treat it as a retransmission.
        """
        if packet.header.syn_nr == self.recv_ack_nr:
            with self.receive_buffer_lock:
                if packet.header.data_length > 0:
                    self.receive_buffer.append(packet.data)
                self.recv_ack_nr = self.__next_seq(self.recv_ack_nr)

                # Move the run of buffered packets that became contiguous. Each step
                # looks up the exact next number, so the run keeps going across the
                # sequence-number wraparound (the buffer is sorted by raw SYN nr).
                index = self.__buffered_index(self.recv_ack_nr)
                while index is not None:
                    buffered = self.receive_ack_buffer.pop(index)
                    if buffered.header.data_length > 0:
                        self.receive_buffer.append(buffered.data)
                    self.recv_ack_nr = self.__next_seq(self.recv_ack_nr)
                    index = self.__buffered_index(self.recv_ack_nr)

                self.receive_bytes_available.notify()

        elif self.__seq_between(packet.header.syn_nr, self.recv_ack_nr + 1,
                                self.recv_ack_nr + self.local_window_size - 1):
            # There is a gap, store this packet
            insert_index = bisect.bisect_left(self.receive_ack_buffer, packet)
            if (insert_index < len(self.receive_ack_buffer)
                    and self.receive_ack_buffer[insert_index].header.syn_nr == packet.header.syn_nr):
                if enable_more_asserts:
                    assert self.receive_ack_buffer[insert_index].data == packet.data
                logging.info(self.deb + "Spurious retransmission of packet in window")
            else:
                self.receive_ack_buffer.insert(insert_index, packet)
        else:
            logging.info(self.deb + "Spurious retransmission of packet before window")

    def __buffered_index(self, syn_nr: int) -> Optional[int]:
        """Index of the packet with sequence number ``syn_nr`` in receive_ack_buffer, or None."""
        probe = _Packet(_Header(0, 0, syn_nr, 0, 0, 0, 0, 0), b"")
        index = bisect.bisect_left(self.receive_ack_buffer, probe)
        if index < len(self.receive_ack_buffer) and self.receive_ack_buffer[index].header.syn_nr == syn_nr:
            return index
        return None

    def __print_window(self) -> None:
        """Debug aid: print the receive window ('.' = missing, 'R' = buffered out of order)."""
        print(self.deb + f"ACK {self.recv_ack_nr}")
        syn = self.recv_ack_nr
        if len(self.receive_ack_buffer) > 0:
            index = bisect.bisect_left(self.receive_ack_buffer, _Packet(_Header(0, 0, syn, 0, 0, 0, 0, 0), b""))
            # bisect may return len(buffer); wrap it like the loop below does
            index %= len(self.receive_ack_buffer)
            while syn < self.recv_ack_nr + self.local_window_size:
                while (self.receive_ack_buffer[index].header.syn_nr != syn
                       < self.recv_ack_nr + self.local_window_size):
                    print(".", end="")
                    syn += 1
                while (self.receive_ack_buffer[index].header.syn_nr == syn
                       < self.recv_ack_nr + self.local_window_size):
                    print("R", end="")
                    syn += 1
                    index = (index + 1) % len(self.receive_ack_buffer)
            print()
        else:
            print("." * self.local_window_size)

    def poll_sender(self) -> None:
        """Retransmit timed-out packets, then send queued data as far as the peer's window allows."""
        now = time.perf_counter()
        lost: List[_TimestampPacket] = []
        pending: List[_TimestampPacket] = []
        for p in self.sent_ack_buffer:
            (lost if now - p.timestamp >= self.timeout else pending).append(p)
        self.sent_ack_buffer = pending

        if len(lost) > 0:
            logging.warning(self.deb + f"{len(lost)} lost messages")
        for p in lost:
            # Re-sending appends a fresh _TimestampPacket (new timestamp, current ACK nr)
            self.__send_to_be_acked(p.data, p.header.syn_nr, p.header.flags_obj())

        if self.expect_syn_ack:
            return  # remote_window_size is only known once the SYN-ACK arrives

        if len(self.send_buffer) > 0:
            logging.debug(self.deb + f"{len(self.send_buffer)} unsent messages")

        if len(self.sent_ack_buffer) < self.remote_window_size:
            with self.send_buffer_lock:
                while len(self.send_buffer) > 0 and len(self.sent_ack_buffer) < self.remote_window_size:
                    data = self.send_buffer.pop(0)
                    self.__send_to_be_acked(data, self.send_syn_nr)
                    self.send_syn_nr = self.__next_seq(self.send_syn_nr)
                if len(self.send_buffer) == 0:
                    self.send_buffer_empty.notify()

    def enqueue_data(self, data: bytes) -> None:
        """Queue ``data`` for sending, split into chunks of at most ``payload_size`` bytes."""
        logging.debug(self.deb + f"Enqueuing {len(data)} bytes")
        assert not self.sender_closed
        chunks = [data[offset:offset + payload_size] for offset in range(0, len(data), payload_size)]
        with self.send_buffer_lock:
            self.send_buffer.extend(chunks)

    def __at_least_received(self, count: int) -> bool:
        """Whether receive_buffer holds at least ``count`` bytes. Call with receive_buffer_lock held."""
        i = 0
        while i < len(self.receive_buffer) and count > 0:
            count -= len(self.receive_buffer[i])
            i += 1
        logging.debug(self.deb + f"{count} bytes short")
        return count <= 0

    def dequeue_data_array(self, count: int) -> List[bytes]:
        """Block until ``count`` bytes are available, then remove and return them as chunks.

        Returns an empty list for ``count <= 0``.
        """
        logging.debug(self.deb + f"Dequeuing {count} bytes...")
        if count <= 0:
            # Nothing to take; the old code indexed receive_buffer[0] here and
            # raised IndexError when the buffer was empty.
            logging.debug(self.deb + "Dequeued")
            return []

        chunks: List[bytes] = []
        with self.receive_buffer_lock:
            self.receive_bytes_available.wait_for(lambda: self.__at_least_received(count))
            remaining = count
            while remaining > 0:
                head = self.receive_buffer[0]
                if len(head) <= remaining:
                    # Fully consumed chunks are removed, so no empty chunks linger
                    chunks.append(self.receive_buffer.pop(0))
                    remaining -= len(head)
                else:
                    chunks.append(head[:remaining])
                    self.receive_buffer[0] = head[remaining:]
                    remaining = 0

        logging.debug(self.deb + "Dequeued")
        return chunks

    def dequeue_data(self, count: int) -> bytes:
        """Block until ``count`` bytes are available and return them as one bytes object."""
        return b"".join(self.dequeue_data_array(count))

    def dequeue_is_available(self, count: int) -> bool:
        """Whether ``count`` bytes can be dequeued without blocking."""
        with self.receive_buffer_lock:
            return self.__at_least_received(count)

    def dequeue_all_data(self) -> bytes:
        """Remove and return everything currently in receive_buffer, without blocking."""
        logging.debug(self.deb + "Dequeuing all data...")
        with self.receive_buffer_lock:
            # join is linear; repeated `+=` on bytes was quadratic
            data = b"".join(self.receive_buffer)
            self.receive_buffer.clear()
        logging.debug(self.deb + f"Dequeued {len(data)} bytes")
        return data

    def close_sender(self) -> None:
        """Flush queued data, wait until the peer has ACKed everything, then send our FIN.

        Blocks the calling thread. Does nothing if our FIN was already sent.
        """
        logging.debug(self.deb + "Close sender")
        if self.sender_closed:
            logging.debug(self.deb + "Sender is already closed")
            return

        logging.debug(self.deb + "Waiting until send_buffer is empty")
        with self.send_buffer_lock:
            self.send_buffer_empty.wait_for(lambda: len(self.send_buffer) == 0)

        logging.debug(self.deb + "Waiting until sent_ack_buffer is empty")
        with self.all_acked_lock:
            self.all_acked.wait_for(lambda: self.send_ack_nr == self.send_syn_nr)

        # NOTE(review): this appends to sent_ack_buffer on the caller's thread while
        # poll_sender / pass_received_msgs (presumably driven by the binding's
        # background thread) may rebuild it concurrently; consider a lock for it.
        self.__send_to_be_acked(b"", self.send_syn_nr, _Flags((False, True, True)))
        self.send_syn_nr = self.__next_seq(self.send_syn_nr)
        self.sender_closed = True

    def close_receiver(self) -> None:
        """Call after close_sender.

        Waits for the peer's FIN, then lingers for 4 * timeout. Presumably the
        linger lets retransmitted FINs still be ACKed (TIME-WAIT style); confirm.
        """
        logging.debug(self.deb + "Waiting for remote host to close")
        with self.receiver_closed_lock:
            self.receiver_closed_event.wait_for(lambda: self.receiver_closed)

        time.sleep(self.timeout * 4)

461 462 463 464 465

class _Connection:
    """Application-facing handle to one _Stream.

    The parent is the _Server that accepted the connection or the Binding
    that opened it; close() removes the stream or socket from that parent.
    """

    def __init__(self, parent: Union["_Server", "Binding"], stream: _Stream):
        self.parent = parent
        self.stream = stream
        self.deb = f"{self.__parent_str()}|{self.stream.local_port}->{self.stream.remote_port}: "
        logging.debug(self.deb + "Setup connection")

    def __parent_str(self) -> str:
        """Local UDP address of the owning binding, formatted for log prefixes."""
        binding = self.parent if isinstance(self.parent, Binding) else self.parent.binding
        return str(binding.local_udp_addr)

    def _pass_received_msgs(self, packets: List[_AddrPacket]) -> None:
        self.stream.pass_received_msgs(packets)

    def _poll_streams(self) -> None:
        self.stream.poll_sender()

    def send(self, data: bytes) -> None:
        """Queue ``data`` for reliable delivery; does not wait for it to be sent."""
        self.stream.enqueue_data(data)

    def receive_array(self, count: int) -> List[bytes]:
        """Block until ``count`` bytes arrived; return them as a list of chunks."""
        return self.stream.dequeue_data_array(count)

    def receive(self, count: int) -> bytes:
        """Block until ``count`` bytes arrived; return them as one bytes object."""
        return self.stream.dequeue_data(count)

    def receive_is_available(self, count: int) -> bool:
        """Whether ``count`` bytes can be received without blocking."""
        return self.stream.dequeue_is_available(count)

    def receive_all(self) -> bytes:
        """Return all data received so far, without blocking."""
        return self.stream.dequeue_all_data()

    def close_sender(self) -> None:
        """Half-close: flush our data and send FIN, but keep receiving."""
        logging.debug(self.deb + "Close sender")
        self.stream.close_sender()

    def close(self) -> None:
        """Fully close the connection, then unregister it from its parent."""
        logging.debug(self.deb + "Closing connection...")
        self.stream.close_sender()
        self.stream.close_receiver()
        if isinstance(self.parent, _Server):
            # noinspection PyProtectedMember
            self.parent._remove_stream(self.stream.remote_port)
        else:
            # noinspection PyProtectedMember
            self.parent._remove_socket(self.stream.local_port)
        logging.debug(self.deb + "Closed")
508 509 510 511 512 513


class _Server:
    """Listening side of a local bTCP port.

    Routes incoming packets to the _Stream of each connected peer, keyed by
    the peer's bTCP port. A bare SYN from an unknown peer becomes a new
    connection waiting in the backlog for accept().
    """

    def __init__(self, binding: "Binding", local_port: int):
        self.binding = binding
        self.local_port = local_port
        self.streams: Dict[int, _Stream] = {}  # Remote port -> _Stream
        self.streams_lock = threading.Lock()

        # Remaining number of incoming connections to admit (set by start_listen,
        # decremented for every connection created).
        self.backlog = 0
        self.backlog_connections: List[_Connection] = []
        self.backlog_lock = threading.Lock()
        self.backlog_available = threading.Condition(self.backlog_lock)

        self.deb = f"{self.binding.local_udp_addr}|{self.local_port}: "
        logging.debug(self.deb + "Setup server")

    def _pass_received_msgs(self, packets: List[_AddrPacket]) -> None:
        """Dispatch packets received on this port to their streams and admit new connections."""
        logging.debug(self.deb + f"Received {len(packets)} messages")

        packet_batches: Dict[int, List[_AddrPacket]] = {}
        orphan_packets: List[_AddrPacket] = []

        with self.streams_lock:
            for packet in packets:
                if packet.header.src_port in self.streams:
                    packet_batches.setdefault(packet.header.src_port, []).append(packet)
                else:
                    orphan_packets.append(packet)

            for remote_port, batch in packet_batches.items():
                self.streams[remote_port].pass_received_msgs(batch)

            for packet in orphan_packets:
                self.__handle_orphan(packet)

    def __handle_orphan(self, packet: _AddrPacket) -> None:
        """Handle a packet whose sender had no stream when the batch was split. Call with streams_lock held."""
        existing = self.streams.get(packet.header.src_port)
        if existing is not None:
            # An earlier SYN in this same batch already created the stream (e.g. a
            # retransmitted SYN). Let the stream handle the duplicate instead of
            # creating a second connection for the same peer.
            existing.pass_received_msgs([packet])
            return

        flags = packet.header.flags_obj()
        if not (flags.syn and not flags.ack and not flags.fin
                and packet.header.data_length == 0):
            logging.warning(self.deb + "Unknown connection")
            return

        with self.backlog_lock:
            admitted = self.backlog > 0
            if admitted:
                self.streams[packet.header.src_port] = stream = _Stream(
                    self.binding, self.local_port,
                    packet.header.src_port, packet.remote_udp_addr,
                    self.binding.window_size, self.binding.timeout_sec,
                    _RemoteInit(packet.header.syn_nr, packet.header.window))
                self.backlog_connections.append(_Connection(self, stream))
                self.backlog -= 1
                self.backlog_available.notify()
        if not admitted:
            logging.warning(self.deb + "Backlog full")

    def _poll_streams(self) -> None:
        """Run the send/retransmit step of every stream."""
        with self.streams_lock:
            for stream in self.streams.values():
                stream.poll_sender()

    def _remove_stream(self, remote_port: int) -> None:
        """Forget the stream of ``remote_port``; a no-op if it is already gone."""
        logging.debug(self.deb + "Remove stream")
        with self.streams_lock:
            self.streams.pop(remote_port, None)

    def start_listen(self, backlog: int) -> None:
        """Start admitting up to ``backlog`` incoming connections."""
        logging.debug(self.deb + "Start listen")
        with self.backlog_lock:
            self.backlog = backlog

    def accept(self) -> _Connection:
        """Block until an incoming connection is available and return it."""
        logging.debug(self.deb + "Accepting...")
        with self.backlog_lock:
            self.backlog_available.wait_for(lambda: len(self.backlog_connections) > 0)
            connection = self.backlog_connections.pop(0)
        logging.debug(self.deb + "Accepted")
        return connection

    def close(self) -> None:
        """Gracefully close every stream of this server, then release the local port.

        ``streams`` is keyed by remote port, so iterate over a snapshot of its
        keys. The lock is not held while a stream closes: closing waits for
        ACKs/FIN that arrive via _pass_received_msgs, which needs the lock.
        """
        logging.debug(self.deb + "Closing...")
        with self.streams_lock:
            remote_ports = list(self.streams)
        for remote_port in remote_ports:
            with self.streams_lock:
                stream = self.streams.get(remote_port)
            if stream is None:
                continue  # Already closed through its _Connection
            stream.close_sender()
            stream.close_receiver()
            self._remove_stream(remote_port)
        # noinspection PyProtectedMember
        self.binding._remove_socket(self.local_port)
        logging.debug(self.deb + "Closed")
601 602 603 604


class Binding:
    """One UDP socket shared by every bTCP port bound through it.

    ``sockets`` maps a local bTCP port to the ``_Server`` or ``_Connection`` bound
    to it. A background thread (``read_thread``) polls the UDP socket, validates
    incoming datagrams, hands them to the socket registered for their destination
    port and then lets every registered socket poll its streams.
    """

    def __init__(self, protocol: int, local_udp_addr: Optional[Any],
                 window_size: int, timeout_ms: int, poll_time_ms: Optional[int] = None):
        """
        :param protocol: Address family, passed as the first argument of ``socket.socket``.
        :param local_udp_addr: Address to bind the UDP socket to, or None to leave it unbound.
        :param window_size: Window size given to streams created by ``connect_client``.
        :param timeout_ms: Timeout in milliseconds; stored as seconds for ``connect_client``.
        :param poll_time_ms: Poll wait of the receive thread in milliseconds. A falsy
            value (None or 0) selects ``min(25, timeout_ms / 4)``.
        """
        self.local_udp_addr = local_udp_addr
        self.window_size = window_size
        self.timeout_sec = timeout_ms / 1000
        self.poll_time_ms = poll_time_ms or min(25., timeout_ms / 4)
        # Must exist before the receive thread starts, because that thread logs with it.
        self.deb = f"{self.local_udp_addr}: "

        self.sock = socket.socket(protocol, socket.SOCK_DGRAM)
        if local_udp_addr is not None:
            self.sock.bind(local_udp_addr)

        self.sockets: Dict[int, Union[_Server, _Connection]] = {}  # Local port -> _Server / _Connection
        self.stop = False
        self.sockets_stop_lock = threading.Lock()  # Guards self.sockets and self.stop

        self.read_thread = threading.Thread(None, self.__background)
        self.read_thread.start()

        logging.debug(self.deb + "Set up binding")

    def __background(self) -> None:
        """Receive loop run by ``read_thread`` until ``close`` sets ``stop``."""
        poller = select.poll()  # select.poll() is not available on Windows
        poller.register(self.sock, select.POLLIN)

        while True:
            poll_result = poller.poll(self.poll_time_ms)

            packet_batches: Dict[int, List[_AddrPacket]] = {}

            # `with` guarantees the lock is released on exit and on any exception,
            # so a failure here cannot leave close() deadlocked on the acquire.
            with self.sockets_stop_lock:
                if self.stop:
                    break

                # Drain every datagram that is ready, grouping them per local port.
                while poll_result:
                    data, addr = self.sock.recvfrom(header_size + payload_size, socket.MSG_WAITALL)
                    logging.debug(self.deb + f"Packet from {addr}")
                    packet = self.__parse_packet(data, addr)
                    if packet is not None:
                        if packet.header.dest_port in self.sockets:
                            packet_batches.setdefault(packet.header.dest_port, []).append(packet)
                        else:
                            logging.warning(self.deb + "Unknown server")
                    poll_result = poller.poll(0)

                for local_port, packets in packet_batches.items():
                    # noinspection PyProtectedMember
                    self.sockets[local_port]._pass_received_msgs(packets)

                for server in self.sockets.values():
                    # noinspection PyProtectedMember
                    server._poll_streams()

    def __parse_packet(self, data: bytes, addr: Any) -> "Optional[_AddrPacket]":
        """Decode one received datagram; return None (after logging) if it must be dropped."""
        if len(data) < header_size:
            # struct.unpack would raise on a truncated header and kill the receive thread.
            logging.warning(self.deb + "Packet shorter than header")
            return None
        packet = _AddrPacket(_Header.unpack(data[:header_size]),
                             data[header_size:], addr)
        if packet.header.data_length > payload_size:
            logging.warning(self.deb + "data_length too large")
        # Strip the zero padding that _send appends after the payload.
        packet.data = packet.data[:packet.header.data_length]

        if not packet.verify_checksum():
            logging.warning(self.deb + "Invalid checksum")
            return None
        return packet

    def _send(self, packet: _Packet, remote_udp_addr: Any) -> None:
        """Send ``packet`` to ``remote_udp_addr``, zero-padding its payload to ``payload_size``."""
        logging.debug(self.deb + f"Send to {remote_udp_addr}")
        data = bytes(packet.header) + packet.data + bytes(payload_size - len(packet.data))
        # sendto returns the number of bytes sent; keep sending any remainder.
        while len(data) > 0:
            data = data[self.sock.sendto(data, remote_udp_addr):]

    def _remove_socket(self, local_port: int) -> None:
        """Unregister the socket bound to ``local_port``.

        :raises KeyError: If nothing is bound to ``local_port`` (the lock is still released).
        """
        logging.debug(self.deb + f"Remove socket {local_port}")
        with self.sockets_stop_lock:
            del self.sockets[local_port]

    def bind_server(self, local_btcp_port: int) -> _Server:
        """Create a ``_Server`` on ``local_btcp_port``, register it and return it."""
        logging.debug(self.deb + f"Bind server to btcp {local_btcp_port}")
        with self.sockets_stop_lock:
            server = self.sockets[local_btcp_port] = _Server(self, local_btcp_port)
        return server

    def connect_client(self, local_btcp_port: int, remote_btcp_port: int, remote_udp_addr: Any) -> _Connection:
        """Create a client ``_Connection`` from ``local_btcp_port`` to the given remote, register and return it."""
        logging.debug(self.deb + f"Connect to btcp {local_btcp_port} -> {remote_btcp_port} {remote_udp_addr}")
        with self.sockets_stop_lock:
            connection = _Connection(self, _Stream(self, local_btcp_port, remote_btcp_port, remote_udp_addr,
                                                   self.window_size, self.timeout_sec))
            self.sockets[local_btcp_port] = connection
        return connection

    def close(self) -> None:
        """Close every registered socket, stop the receive thread, then close the UDP socket."""
        logging.debug(self.deb + "Closing...")

        # Each socket's close() is relied on to unregister itself via _remove_socket;
        # otherwise this loop would not terminate. The receive thread keeps running
        # meanwhile, and the lock is not held here because _remove_socket takes it.
        while self.sockets:
            next(iter(self.sockets.values())).close()

        # Stop and join the receive thread *before* closing the UDP socket, so the
        # thread never polls or reads from an already-closed descriptor.
        with self.sockets_stop_lock:
            self.stop = True
        self.read_thread.join()

        self.sock.close()
        logging.debug(self.deb + "Closed")