btcp.py 28.4 KB
Newer Older
StevenWdV's avatar
StevenWdV committed
1
import binascii
2
import bisect
StevenWdV's avatar
StevenWdV committed
3
import copy
StevenWdV's avatar
StevenWdV committed
4
import logging
StevenWdV's avatar
StevenWdV committed
5
import random
6 7 8 9 10 11
import select
import socket
import struct
import threading
import time
from typing import *
StevenWdV's avatar
StevenWdV committed
12

13 14
# TODO test btcp port multiplexing

# Wire layout of a bTCP header in network byte order: source port, destination port,
# SYN nr, ACK nr, flags byte, window, data length, CRC32 checksum.
header_format = "!HHHHBBHI"
header_size = struct.calcsize(header_format)  # 16 bytes
payload_size = 1000  # Size of the payload area of every datagram (_send pads up to it)
# Sequence numbers are reduced modulo this value, so they range over 0..max_seq-1.
max_seq = 0xffFF

debug_window = False  # When set, print a picture of the receive window after each batch
21

22 23 24 25 26 27 28 29 30 31 32 33 34

class _Flags:
    """The three bTCP control flags, convertible to and from the packed flags byte."""

    def __init__(self, flags: Union[int, Tuple[bool, bool, bool]] = 0):
        """
        :param flags: Packed flags byte (bit 2 = SYN, bit 1 = ACK, bit 0 = FIN)
                      or a (SYN, ACK, FIN) tuple
        """
        if type(flags) is not int:
            # Tuple form: take the values as given.
            self.syn, self.ack, self.fin = flags
            return
        self.syn = bool(flags & 0b100)
        self.ack = bool(flags & 0b010)
        self.fin = bool(flags & 0b001)

    def __int__(self) -> int:
        return self.fin | self.ack << 1 | self.syn << 2

    def __str__(self) -> str:
        named = (("SYN", self.syn), ("ACK", self.ack), ("FIN", self.fin))
        return "-".join(name for name, is_set in named if is_set)

StevenWdV's avatar
StevenWdV committed
48

49
class _Header:
    """bTCP segment header, serialized with the module-level header_format."""

    @staticmethod
    def unpack(header: bytes) -> "_Header":
        """Parse exactly struct.calcsize(header_format) bytes of wire data into a _Header."""
        fields = struct.unpack(header_format, header)
        return _Header(*fields)

    def __init__(self, src_port: int, dest_port: int, syn_nr: int, ack_nr: int, flags: int,
                 window: int, data_length: int, checksum: int = 0):
        """
        :param flags: Packed flags byte (see _Flags)
        :param window: Advertised window size
        :param data_length: Number of meaningful payload bytes
        :param checksum: CRC32 over header and data, computed with this field set to 0
        """
        self.src_port, self.dest_port = src_port, dest_port
        self.syn_nr, self.ack_nr = syn_nr, ack_nr
        self.flags = flags
        self.window = window
        self.data_length = data_length
        self.checksum = checksum

    def flags_obj(self) -> "_Flags":
        """Decoded view of the flags byte."""
        return _Flags(self.flags)

    def __bytes__(self) -> bytes:
        values = (self.src_port, self.dest_port, self.syn_nr, self.ack_nr,
                  self.flags, self.window, self.data_length, self.checksum)
        return struct.pack(header_format, *values)
StevenWdV's avatar
StevenWdV committed
79

StevenWdV's avatar
StevenWdV committed
80

81 82
class _Packet:
    """A header plus its payload; ordered by raw SYN nr so packet lists can be bisected."""

    def __init__(self, header: "_Header", data: bytes):
        self.header = header
        self.data = data

    def compute_checksum(self) -> int:
        """CRC32 of the serialized header (with its checksum field zeroed) followed by the data."""
        zeroed = self.header
        if zeroed.checksum != 0:
            # Work on a copy so the checksum stored in self.header stays intact.
            zeroed = copy.deepcopy(zeroed)
            zeroed.checksum = 0
        return binascii.crc32(bytes(zeroed) + self.data)

    def verify_checksum(self) -> bool:
        """Whether the stored checksum matches the packet contents."""
        expected = self.compute_checksum()
        return expected == self.header.checksum

    def set_checksum(self) -> None:
        """Store the checksum of the current contents in the header."""
        self.header.checksum = self.compute_checksum()

    def __lt__(self, other: "_Packet") -> bool:
        return other.header.syn_nr > self.header.syn_nr
StevenWdV's avatar
StevenWdV committed
102

StevenWdV's avatar
StevenWdV committed
103

104 105 106 107
class _TimestampPacket(_Packet):
    """A sent packet remembered together with its send time, for retransmission timeouts."""

    def __init__(self, header: _Header, data: bytes, timestamp: float):
        """
        :param timestamp: Send time as produced by time.perf_counter()
        """
        _Packet.__init__(self, header, data)
        self.timestamp = timestamp
StevenWdV's avatar
StevenWdV committed
108 109


110 111 112 113
class _AddrPacket(_Packet):
    """A received packet tagged with the UDP address it arrived from."""

    def __init__(self, header: _Header, data: bytes, remote_udp_addr: Tuple[str, int]):
        _Packet.__init__(self, header, data)
        self.remote_udp_addr = remote_udp_addr
StevenWdV's avatar
StevenWdV committed
114 115


StevenWdV's avatar
StevenWdV committed
116
# Parameters taken from the remote host's SYN, used when the remote host initiated the connection
117 118 119 120
class _RemoteInit:
    """SYN nr and advertised window of a remote host's SYN, used to set up a _Stream."""

    def __init__(self, syn_nr: int, window_size: int):
        self.syn_nr, self.window_size = syn_nr, window_size
StevenWdV's avatar
StevenWdV committed
121 122


123
# TODO more timeouts?
124
# TODO fast retransmit
125 126 127
class _Stream:
    """
    One end of a bTCP connection: sequencing, windowing, ACKs and retransmission.

    Sequence numbers are reduced modulo max_seq. The constructor immediately sends
    a SYN (when we initiate, ``remote_init is None``) or a SYN-ACK (when
    ``remote_init`` describes the remote host's SYN).

    pass_received_msgs() and poll_sender() are driven by the Binding's background
    thread; the enqueue/dequeue/close methods are called from user threads and
    synchronize with it through the locks/conditions created in __init__.
    """

    def __init__(self, binding: "Binding", local_port: int,
                 remote_port: int, remote_udp_addr: Any,
                 local_window_size: int, timeout: float,
                 remote_init: Optional[_RemoteInit] = None):
        """
        :param binding: Binding used to transmit packets
        :param local_port: Our bTCP port
        :param remote_port: The remote host's bTCP port
        :param remote_udp_addr: UDP address of the remote host
        :param local_window_size: Window size we advertise and accept packets in
        :param timeout: Retransmission timeout in seconds
        :param remote_init: SYN info if the remote host initiated the connection
        """
        self.binding = binding
        self.local_port = local_port
        self.remote_port = remote_port
        self.local_window_size = local_window_size
        self.remote_udp_addr = remote_udp_addr
        self.timeout = timeout
        self.expect_syn_ack = remote_init is None

        self.send_buffer: List[bytes] = []  # Data to be sent
        self.sent_ack_buffer: List[_TimestampPacket] = []  # Data still to be ACKed by other

        self.receive_ack_buffer: List[_Packet] = []  # Received out of order, kept sorted by SYN nr
        self.receive_buffer: List[bytes] = []  # Received and ACKed, but not yet retrieved

        self.send_buffer_lock = threading.Lock()
        self.send_buffer_empty = threading.Condition(self.send_buffer_lock)
        self.receive_buffer_lock = threading.Lock()
        self.receive_bytes_available = threading.Condition(self.receive_buffer_lock)

        self.all_acked_lock = threading.Lock()
        self.all_acked = threading.Condition(self.all_acked_lock)

        self.receiver_closed = False  # If the remote host has sent a FIN
        self.sender_closed = False  # If we have sent a FIN

        self.receiver_closed_lock = threading.Lock()
        self.receiver_closed_event = threading.Condition(self.receiver_closed_lock)

        self.deb = f"{self.binding.local_udp_addr}|{self.local_port}->{self.remote_port}: "

        # __next_seq reduces modulo max_seq, so valid sequence numbers are
        # 0..max_seq-1; random.randint is inclusive on both ends.
        self.first_send_syn_nr = random.randint(0, max_seq - 1)

        if remote_init is None:  # We initiate the connection
            self.send_syn_nr = self.first_send_syn_nr  # Next to be sent
            self.send_ack_nr = self.send_syn_nr  # First not ACKed by other

            self.first_recv_syn_nr = None

            self.__send_to_be_acked(b"", self.send_syn_nr, _Flags((True, False, False)))

        else:
            self.remote_window_size = remote_init.window_size

            self.recv_syn_nr = remote_init.syn_nr  # Last received
            self.recv_ack_nr = self.__next_seq(remote_init.syn_nr)  # First still to be received
            self.send_syn_nr = self.first_send_syn_nr  # Next to be sent
            self.send_ack_nr = self.send_syn_nr  # First not ACKed by other

            self.first_recv_syn_nr = self.recv_syn_nr

            self.__send_to_be_acked(b"", self.send_syn_nr, _Flags((True, True, False)))

        # The SYN / SYN-ACK occupies one sequence number.
        self.send_syn_nr = self.__next_seq(self.send_syn_nr)
        logging.debug(self.deb + "Setup stream")

    @staticmethod
    def __next_seq(nr: int) -> int:
        """Sequence number following `nr`, wrapping modulo max_seq."""
        return (nr + 1) % max_seq

    @staticmethod
    def __seq_between(nr: int, lower_bound: int, upper_bound: int) -> bool:
        """
        Whether `nr` lies in the inclusive range [lower_bound, upper_bound] of the
        circular sequence space; all three values are reduced modulo max_seq first.
        """
        nr %= max_seq
        lower_bound %= max_seq
        upper_bound %= max_seq
        if lower_bound <= upper_bound:
            return lower_bound <= nr <= upper_bound
        else:
            return nr >= lower_bound or nr <= upper_bound

    def __send_to_be_acked(self, data: bytes, syn_nr: int, flags=_Flags((False, True, False))) -> None:
        """
        Send a packet with sequence number `syn_nr` and keep it in sent_ack_buffer
        until it is ACKed or times out (see poll_sender). Our current ACK nr is
        included when `flags.ack` is set.
        """
        logging.debug(self.deb +
                      f"Sending #{syn_nr} {flags}{f' with ACK {self.recv_ack_nr}' if flags.ack else ''} "
                      f"with {len(data)} bytes")
        header = _Header(self.local_port, self.remote_port, syn_nr, self.recv_ack_nr if flags.ack else 0,
                         int(flags), self.local_window_size, len(data))
        packet = _TimestampPacket(header, data, time.perf_counter())
        packet.set_checksum()
        self.sent_ack_buffer.append(packet)
        # noinspection PyProtectedMember
        self.binding._send(packet, self.remote_udp_addr)

    def __send_ack(self) -> None:
        """Send a bare ACK for recv_ack_nr; it is not buffered and never retransmitted."""
        logging.debug(self.deb + f"Sending ACK {self.recv_ack_nr}")
        ack_header = _Header(self.local_port, self.remote_port, self.send_syn_nr, self.recv_ack_nr,
                             int(_Flags((False, True, False))), self.local_window_size, 0)

        ack_packet = _Packet(ack_header, b"")
        ack_packet.set_checksum()
        # noinspection PyProtectedMember
        self.binding._send(ack_packet, self.remote_udp_addr)

    def pass_received_msgs(self, packets: List[_AddrPacket]) -> None:
        """
        Process a batch of packets for this stream.

        Completes the handshake, drops ACKed packets from sent_ack_buffer, moves
        in-order data into receive_buffer (buffering out-of-order data that falls
        inside the window), records FINs, and finally sends an ACK unless one can
        be piggybacked on pending outgoing data.
        """
        logging.debug(self.deb + f"Received {len(packets)} messages")
        send_ack = False

        for packet in packets:
            flags = packet.header.flags_obj()
            logging.debug(self.deb + f"Received #{packet.header.syn_nr} {flags}"
                          f"{f' with ACK {packet.header.ack_nr}' if flags.ack else ''} "
                          f"with {len(packet.data)} bytes")

            if self.expect_syn_ack:
                if flags.syn and flags.ack and not flags.fin:
                    if packet.header.ack_nr != self.__next_seq(self.send_ack_nr):
                        logging.warning(self.deb + "Wrong ACK nr in SYN-ACK")
                        continue
                    self.expect_syn_ack = False
                    self.recv_syn_nr = packet.header.syn_nr
                    self.recv_ack_nr = self.__next_seq(packet.header.syn_nr)
                    self.first_recv_syn_nr = self.recv_syn_nr
                    self.remote_window_size = packet.header.window
                    send_ack = True
                else:
                    logging.warning(self.deb + "SYN-ACK expected")
                    continue
            elif flags.syn:
                if self.first_recv_syn_nr is not None and packet.header.syn_nr == self.first_recv_syn_nr:
                    logging.warning(self.deb + "Spurious SYN")
                    # A retransmitted SYN / SYN-ACK means our ACK of it may have been
                    # lost; re-ACK as we do for other duplicate packets.
                    send_ack = True
                    continue
                else:
                    logging.warning(self.deb + "Unexpected SYN")
                    # TODO simultaneous open if self.first_recv_syn_nr is None
                    continue

            if not self.__seq_between(packet.header.syn_nr, self.recv_ack_nr - self.local_window_size,
                                      self.recv_ack_nr + self.local_window_size - 1):
                logging.warning(self.deb + "SYN nr outside of window (spurious retransmission?)")
                continue

            if self.__seq_between(packet.header.syn_nr, self.recv_ack_nr + 1,
                                  self.recv_ack_nr + self.local_window_size - 1):
                self.recv_syn_nr = packet.header.syn_nr

            if flags.ack:
                if self.__seq_between(packet.header.ack_nr, self.send_ack_nr + 1, self.send_syn_nr):
                    # Delete ACKed messages
                    self.sent_ack_buffer = [
                        p for p in self.sent_ack_buffer
                        if not self.__seq_between(p.header.syn_nr, self.send_ack_nr, packet.header.ack_nr - 1)]
                    self.send_ack_nr = packet.header.ack_nr
                    if self.send_ack_nr == self.send_syn_nr:
                        with self.all_acked_lock:
                            self.all_acked.notify()

            if flags.fin:
                # NOTE(review): this is recorded as soon as the FIN arrives, even if
                # earlier data is still missing, so close_receiver() may return
                # before all data has been delivered — confirm this is acceptable.
                with self.receiver_closed_lock:
                    self.receiver_closed = True
                    self.receiver_closed_event.notify()

            if packet.header.data_length > 0 or flags.fin:
                if packet.header.syn_nr == self.recv_ack_nr:
                    # Move run of received packets to receive buffer
                    with self.receive_buffer_lock:
                        if packet.header.data_length > 0:
                            self.receive_buffer.append(packet.data)
                        self.recv_ack_nr = self.__next_seq(self.recv_ack_nr)

                        index = bisect.bisect_left(self.receive_ack_buffer, packet)
                        while (index < len(self.receive_ack_buffer)
                               and self.receive_ack_buffer[index].header.syn_nr == self.recv_ack_nr):
                            packet = self.receive_ack_buffer.pop(index)
                            if packet.header.data_length > 0:
                                self.receive_buffer.append(packet.data)
                            self.recv_ack_nr = self.__next_seq(self.recv_ack_nr)
                            # Wrap around: after max_seq-1 the run continues at the start of the list.
                            if len(self.receive_ack_buffer) > 0:
                                index %= len(self.receive_ack_buffer)

                        self.receive_bytes_available.notify()

                elif self.__seq_between(packet.header.syn_nr, self.recv_ack_nr + 1,
                                        self.recv_ack_nr + self.local_window_size - 1):
                    # There is a gap, store this packet
                    insert_index = bisect.bisect_left(self.receive_ack_buffer, packet)

                    if (insert_index < len(self.receive_ack_buffer)
                            and self.receive_ack_buffer[insert_index].header.syn_nr == packet.header.syn_nr):
                        assert self.receive_ack_buffer[insert_index].data == packet.data
                        logging.warning(self.deb + "Spurious retransmission of packet in window")
                    else:
                        self.receive_ack_buffer.insert(insert_index, packet)
                else:
                    logging.warning(self.deb + "Spurious retransmission of packet before window")

                send_ack = True

        if send_ack:
            # Don't send ACK now if we can piggyback it on the data
            with self.send_buffer_lock:
                data_to_send = len(self.send_buffer) > 0
            if not data_to_send:
                self.__send_ack()

        if debug_window:
            # Debug output only. NOTE(review): indexing below can raise IndexError when
            # bisect_left returns len(receive_ack_buffer).
            print(self.deb + f"ACK {self.recv_ack_nr}")
            syn = self.recv_ack_nr
            if len(self.receive_ack_buffer) > 0:
                index = bisect.bisect_left(self.receive_ack_buffer, _Packet(_Header(0, 0, syn, 0, 0, 0, 0, 0), b""))
                while syn < self.recv_ack_nr + self.local_window_size:
                    while (self.receive_ack_buffer[index].header.syn_nr != syn
                           < self.recv_ack_nr + self.local_window_size):
                        print(".", end="")
                        syn += 1
                    while (self.receive_ack_buffer[index].header.syn_nr == syn
                           < self.recv_ack_nr + self.local_window_size):
                        print("R", end="")
                        syn += 1
                        index = (index + 1) % len(self.receive_ack_buffer)
                print()
            else:
                print("." * self.local_window_size)

    def poll_sender(self) -> None:
        """
        Retransmit packets whose timeout expired, then send queued data while the
        number of unACKed packets stays below the remote window.
        """
        curtime = time.perf_counter()
        lost = [p for p in self.sent_ack_buffer if curtime - p.timestamp >= self.timeout]
        self.sent_ack_buffer = [p for p in self.sent_ack_buffer if curtime - p.timestamp < self.timeout]

        if len(lost) > 0:
            logging.warning(self.deb + f"{len(lost)} lost messages")
        for p in lost:
            # Re-sending builds a fresh timestamped packet carrying our current ACK nr.
            self.__send_to_be_acked(p.data, p.header.syn_nr, p.header.flags_obj())

        if self.expect_syn_ack:
            # remote_window_size is unknown until the SYN-ACK arrives.
            return

        if len(self.send_buffer) > 0:
            logging.debug(self.deb + f"{len(self.send_buffer)} unsent messages")

        if len(self.sent_ack_buffer) < self.remote_window_size:
            with self.send_buffer_lock:
                while len(self.send_buffer) > 0 and len(self.sent_ack_buffer) < self.remote_window_size:
                    data = self.send_buffer.pop(0)
                    self.__send_to_be_acked(data, self.send_syn_nr)
                    self.send_syn_nr = self.__next_seq(self.send_syn_nr)
                if len(self.send_buffer) == 0:
                    self.send_buffer_empty.notify()

    def enqueue_data(self, data: bytes) -> None:
        """Split `data` into payload_size chunks and queue them for poll_sender()."""
        logging.debug(self.deb + f"Enqueuing {len(data)} bytes")
        assert not self.sender_closed
        chunks = [data[offset:offset + payload_size] for offset in range(0, len(data), payload_size)]
        with self.send_buffer_lock:
            self.send_buffer.extend(chunks)

    def __at_least_received(self, count: int) -> bool:
        """Whether receive_buffer holds at least `count` bytes (callers hold receive_buffer_lock)."""
        i = 0
        while i < len(self.receive_buffer) and count > 0:
            count -= len(self.receive_buffer[i])
            i += 1
        logging.debug(self.deb + f"{count} bytes short")
        return count <= 0

    def dequeue_data_array(self, count: int) -> List[bytes]:
        """
        Block until `count` bytes have been received, then remove and return
        exactly `count` bytes as a list of chunks. Returns [] for count <= 0.

        NOTE(review): this waits indefinitely if the remote host closes before
        sending `count` bytes.
        """
        logging.debug(self.deb + f"Dequeuing {count} bytes...")
        data: List[bytes] = []
        if count <= 0:
            # Nothing requested; indexing an empty receive_buffer below would fail.
            logging.debug(self.deb + "Dequeued")
            return data

        with self.receive_buffer_lock:
            self.receive_bytes_available.wait_for(lambda: self.__at_least_received(count))

            remaining = count
            while remaining > 0:
                chunk = self.receive_buffer[0]
                if len(chunk) <= remaining:
                    # Consume the whole chunk so no empty chunk is left behind.
                    data.append(self.receive_buffer.pop(0))
                    remaining -= len(chunk)
                else:
                    data.append(chunk[:remaining])
                    self.receive_buffer[0] = chunk[remaining:]
                    remaining = 0

        logging.debug(self.deb + "Dequeued")
        return data

    def dequeue_data(self, count: int) -> bytes:
        """Like dequeue_data_array(), but joined into a single bytes object."""
        return b"".join(self.dequeue_data_array(count))

    def dequeue_is_available(self, count: int) -> bool:
        """Whether dequeue_data(count) would return without waiting."""
        with self.receive_buffer_lock:
            return self.__at_least_received(count)

    def dequeue_all_data(self) -> bytes:
        """Remove and return everything currently in receive_buffer, without waiting."""
        logging.debug(self.deb + "Dequeuing all data...")
        with self.receive_buffer_lock:
            data = b"".join(self.receive_buffer)
            self.receive_buffer.clear()
        logging.debug(self.deb + f"Dequeued {len(data)} bytes")
        return data

    def close_sender(self) -> None:
        """
        Wait until all queued data has been sent and ACKed, then send a FIN.
        Does nothing if a FIN was already sent.
        """
        logging.debug(self.deb + "Close sender")
        if self.sender_closed:
            logging.debug(self.deb + "Sender is already closed")
            return

        logging.debug(self.deb + "Waiting until send_buffer is empty")
        with self.send_buffer_lock:
            self.send_buffer_empty.wait_for(lambda: len(self.send_buffer) == 0)

        logging.debug(self.deb + "Waiting until sent_ack_buffer is empty")
        with self.all_acked_lock:
            self.all_acked.wait_for(lambda: self.send_ack_nr == self.send_syn_nr)

        # NOTE(review): this runs on the caller's thread, while the background thread
        # may concurrently rebind sent_ack_buffer (poll_sender / pass_received_msgs);
        # the FIN could then be dropped from the retransmission list. Consider handing
        # the FIN to the background thread or guarding sent_ack_buffer with a lock.
        self.__send_to_be_acked(b"", self.send_syn_nr, _Flags((False, True, True)))
        self.send_syn_nr = self.__next_seq(self.send_syn_nr)
        self.sender_closed = True

    def close_receiver(self) -> None:
        """Block until the remote host has sent a FIN."""
        logging.debug(self.deb + "Waiting for remote host to close")
        with self.receiver_closed_lock:
            self.receiver_closed_event.wait_for(lambda: self.receiver_closed)
459 460 461 462 463 464


class _Connection:
    """User-facing handle for one bTCP connection; every operation delegates to its _Stream."""

    def __init__(self, parent: Union["_Server", "Binding"], stream: "_Stream"):
        """
        :param parent: The _Server that accepted the connection, or the Binding a client connected from
        :param stream: The stream carrying this connection
        """
        self.parent = parent
        self.stream = stream
        self.deb = f"{self.__parent_str()}|{self.stream.local_port}->{self.stream.remote_port}: "
        logging.debug(self.deb + "Setup connection")

    def __parent_str(self) -> Any:
        """Local UDP address of the underlying binding, used as log prefix."""
        return self.parent.local_udp_addr if type(self.parent) is Binding else self.parent.binding.local_udp_addr

    def _pass_received_msgs(self, packets: List["_AddrPacket"]) -> None:
        """Hand a batch of received packets to the stream."""
        self.stream.pass_received_msgs(packets)

    def _poll_streams(self) -> None:
        """Let the stream retransmit and send pending data."""
        self.stream.poll_sender()

    def send(self, data: bytes) -> None:
        """Queue `data` for transmission; does not wait for it to be sent."""
        self.stream.enqueue_data(data)

    def receive_array(self, count: int) -> List[bytes]:
        """Block until `count` bytes are available and return them as a list of chunks."""
        return self.stream.dequeue_data_array(count)

    def receive(self, count: int) -> bytes:
        """Block until `count` bytes are available and return them."""
        return self.stream.dequeue_data(count)

    def receive_is_available(self, count: int) -> bool:
        """Whether receive(count) would return without waiting."""
        return self.stream.dequeue_is_available(count)

    def receive_all(self) -> bytes:
        """Return all data received so far without waiting (possibly empty)."""
        return self.stream.dequeue_all_data()

    def close_sender(self) -> None:
        """Wait until all queued data is sent and ACKed, then send a FIN."""
        logging.debug(self.deb + "Close sender")
        self.stream.close_sender()

    def close(self) -> None:
        """Send our FIN, wait for the remote host's FIN, then unregister from the parent."""
        logging.debug(self.deb + "Closing connection...")
        self.stream.close_sender()
        self.stream.close_receiver()
        if type(self.parent) is _Server:
            # Server streams are keyed by the remote port.
            # noinspection PyProtectedMember
            self.parent._remove_stream(self.stream.remote_port)
        else:
            # Binding.sockets is keyed by the local bTCP port (see Binding.connect_client).
            # noinspection PyProtectedMember
            self.parent._remove_socket(self.stream.local_port)
        logging.debug(self.deb + "Closed")
507 508 509 510 511 512


class _Server:
    """A listening bTCP port: turns incoming SYNs into connections handed out by accept()."""

    def __init__(self, binding: "Binding", local_port: int):
        """
        :param binding: Binding that delivers this port's packets
        :param local_port: The bTCP port being served
        """
        self.binding = binding
        self.local_port = local_port
        self.streams: Dict[int, _Stream] = {}  # Remote port -> _Stream
        self.streams_lock = threading.Lock()

        # Number of new connections that may still be created. It is decremented per
        # new connection and only set by start_listen().
        # NOTE(review): accept() does not give a slot back, so this caps the total
        # number of connections rather than the pending queue — confirm intent.
        self.backlog = 0
        self.backlog_connections: List[_Connection] = []  # Created but not yet returned by accept()
        self.backlog_lock = threading.Lock()
        self.backlog_available = threading.Condition(self.backlog_lock)

        self.deb = f"{self.binding.local_udp_addr}|{self.local_port}: "
        logging.debug(self.deb + "Setup server")

    def _pass_received_msgs(self, packets: List["_AddrPacket"]) -> None:
        """
        Route a batch of packets for this port. Packets from known remote ports go
        to their stream; a bare SYN from an unknown port creates a new stream and
        connection if the backlog allows; anything else is dropped with a warning.
        """
        logging.debug(self.deb + f"Received {len(packets)} messages")

        packet_batches: Dict[int, List[_AddrPacket]] = {}
        orphan_packets: List[_AddrPacket] = []

        with self.streams_lock:
            for packet in packets:
                if packet.header.src_port in self.streams:
                    packet_batches.setdefault(packet.header.src_port, []).append(packet)
                else:
                    orphan_packets.append(packet)

            for remote_port, batch in packet_batches.items():
                self.streams[remote_port].pass_received_msgs(batch)

            for packet in orphan_packets:
                flags = packet.header.flags_obj()
                if not (flags.syn and not flags.ack and not flags.fin
                        and packet.header.data_length == 0):
                    logging.warning(self.deb + "Unknown connection")
                    continue

                with self.backlog_lock:
                    accepted = self.backlog > 0
                    if accepted:
                        self.streams[packet.header.src_port] = stream = _Stream(
                            self.binding, self.local_port,
                            packet.header.src_port, packet.remote_udp_addr,
                            self.binding.window_size, self.binding.timeout_sec,
                            _RemoteInit(packet.header.syn_nr, packet.header.window))
                        self.backlog_connections.append(_Connection(self, stream))
                        self.backlog -= 1
                        self.backlog_available.notify()
                if not accepted:
                    logging.warning(self.deb + "Backlog full")

    def _poll_streams(self) -> None:
        """Let every stream retransmit and send pending data."""
        with self.streams_lock:
            for stream in self.streams.values():
                stream.poll_sender()

    def _remove_stream(self, remote_port: int) -> None:
        """Forget the stream for `remote_port`; a no-op if it was already removed (e.g. by close())."""
        logging.debug(self.deb + "Remove stream")
        with self.streams_lock:
            self.streams.pop(remote_port, None)

    def start_listen(self, backlog: int) -> None:
        """Allow up to `backlog` further incoming connections."""
        logging.debug(self.deb + "Start listen")
        with self.backlog_lock:
            self.backlog = backlog

    def accept(self) -> "_Connection":
        """Block until a connection has been created, then return the oldest one."""
        logging.debug(self.deb + "Accepting...")
        with self.backlog_lock:
            self.backlog_available.wait_for(lambda: len(self.backlog_connections) > 0)
            connection = self.backlog_connections.pop(0)
        logging.debug(self.deb + "Accepted")
        return connection

    def close(self) -> None:
        """
        Close every stream of this server (sending our FIN and waiting for the
        remote host's FIN), then unregister the server from its binding.
        """
        logging.debug(self.deb + "Closing...")
        while True:
            with self.streams_lock:
                if not self.streams:
                    break
                # streams is keyed by remote port, so take whichever entry comes first.
                remote_port, stream = next(iter(self.streams.items()))
            # streams_lock must not be held while closing: the ACKs and FINs being
            # waited for are delivered via _pass_received_msgs, which needs this lock.
            stream.close_sender()
            stream.close_receiver()
            with self.streams_lock:
                if self.streams.get(remote_port) is stream:
                    del self.streams[remote_port]
        # noinspection PyProtectedMember
        self.binding._remove_socket(self.local_port)
        logging.debug(self.deb + "Closed")
600 601 602 603


class Binding:
    """
    A UDP socket carrying bTCP traffic for any number of servers and client connections.

    A background thread receives datagrams, routes them by destination bTCP port to
    the registered _Server / _Connection, and periodically polls each of them so
    their streams can (re)transmit.
    """

    def __init__(self, protocol: int, local_udp_addr: Optional[Any],
                 window_size: int, timeout_ms: int, poll_time_ms: Optional[int] = None):
        """
        :param protocol: Address family of the UDP socket (first argument of socket.socket)
        :param local_udp_addr: Address to bind to, or None to leave the socket unbound
        :param window_size: Window size for every stream created on this binding
        :param timeout_ms: Retransmission timeout in milliseconds
        :param poll_time_ms: Background poll interval in milliseconds; defaults to timeout_ms / 2
        """
        self.local_udp_addr = local_udp_addr
        self.window_size = window_size
        self.timeout_sec = timeout_ms / 1000
        self.poll_time_ms = poll_time_ms or timeout_ms / 2
        self.deb = f"{self.local_udp_addr}: "

        self.sock = socket.socket(protocol, socket.SOCK_DGRAM)
        try:
            if local_udp_addr is not None:
                self.sock.bind(local_udp_addr)
        except OSError:
            self.sock.close()
            raise

        self.sockets: Dict[int, Union[_Server, _Connection]] = {}  # Local port -> _Server / _Connection
        self.stop = False
        self.sockets_stop_lock = threading.Lock()

        # Start last: the background thread reads everything set above (including self.deb).
        self.read_thread = threading.Thread(None, self.__background)
        self.read_thread.start()

        logging.debug(self.deb + "Set up binding")

    # TODO use more threads to avoid poll
    def __background(self) -> None:
        """Receive, route and poll until close() sets self.stop."""
        poller = select.poll()  # Does not work on Windows because Python is stupid (WSAPoll is a thing)
        poller.register(self.sock, select.POLLIN)

        while True:
            poll_result = poller.poll(self.poll_time_ms)

            packet_batches: Dict[int, List[_AddrPacket]] = {}

            with self.sockets_stop_lock:
                if self.stop:
                    return

                # Drain every datagram that is ready right now.
                while len(poll_result) > 0:
                    data, addr = self.sock.recvfrom(header_size + payload_size, socket.MSG_WAITALL)
                    logging.debug(self.deb + f"Packet from {addr}")
                    packet = self.__parse_datagram(data, addr)
                    if packet is not None:
                        packet_batches.setdefault(packet.header.dest_port, []).append(packet)
                    poll_result = poller.poll(0)

                for local_port, packets in packet_batches.items():
                    # noinspection PyProtectedMember
                    self.sockets[local_port]._pass_received_msgs(packets)

                for server in self.sockets.values():
                    # noinspection PyProtectedMember
                    server._poll_streams()

    def __parse_datagram(self, data: bytes, addr: Any) -> Optional["_AddrPacket"]:
        """Decode one received datagram; return None (after logging why) if it must be dropped."""
        if len(data) < header_size:
            # Network input is untrusted: struct.unpack would raise on a short datagram
            # and end the background thread.
            logging.warning(self.deb + "Datagram shorter than header")
            return None

        packet = _AddrPacket(_Header.unpack(data[:header_size]),
                             data[header_size:], addr)
        if packet.header.data_length > payload_size:
            logging.warning(self.deb + "data_length too large")
        packet.data = packet.data[:packet.header.data_length]

        if not packet.verify_checksum():
            logging.warning(self.deb + "Invalid checksum")
            return None
        if packet.header.dest_port not in self.sockets:
            logging.warning(self.deb + "Unknown server")
            return None
        return packet

    def _send(self, packet: "_Packet", remote_udp_addr: Any) -> None:
        """Transmit `packet`, zero-padding its payload to payload_size bytes."""
        logging.debug(self.deb + f"Send to {remote_udp_addr}")
        data = bytes(packet.header) + packet.data + bytes(payload_size - len(packet.data))
        while len(data) > 0:
            data = data[self.sock.sendto(data, remote_udp_addr):]

    def _remove_socket(self, local_port: int) -> None:
        """Unregister the server/connection on `local_port`; a no-op if it is not registered."""
        logging.debug(self.deb + f"Remove socket {local_port}")
        with self.sockets_stop_lock:
            self.sockets.pop(local_port, None)

    def bind_server(self, local_btcp_port: int) -> "_Server":
        """Create and register a server on `local_btcp_port`."""
        logging.debug(self.deb + f"Bind server to btcp {local_btcp_port}")
        with self.sockets_stop_lock:
            server = self.sockets[local_btcp_port] = _Server(self, local_btcp_port)
        return server

    def connect_client(self, local_btcp_port: int, remote_btcp_port: int, remote_udp_addr: Any) -> "_Connection":
        """Start connecting from `local_btcp_port` to a remote bTCP port; the SYN is sent immediately."""
        logging.debug(self.deb + f"Connect to btcp {local_btcp_port} -> {remote_btcp_port} {remote_udp_addr}")
        with self.sockets_stop_lock:
            connection = _Connection(self, _Stream(self, local_btcp_port, remote_btcp_port, remote_udp_addr,
                                                   self.window_size, self.timeout_sec))
            self.sockets[local_btcp_port] = connection
        return connection

    def close(self) -> None:
        """Close every server/connection, stop the background thread, then close the UDP socket."""
        logging.debug(self.deb + "Closing...")

        while True:
            with self.sockets_stop_lock:
                if not self.sockets:
                    break
                # sockets is keyed by local port, so take whichever entry comes first.
                local_port, entry = next(iter(self.sockets.items()))
            # The lock must not be held here: closing waits on packets that only the
            # background thread (which needs the lock) can deliver.
            entry.close()
            # Guarantee progress even if entry.close() did not unregister itself.
            with self.sockets_stop_lock:
                if self.sockets.get(local_port) is entry:
                    del self.sockets[local_port]

        with self.sockets_stop_lock:
            self.stop = True
        self.read_thread.join()

        # Only close the UDP socket once the background thread no longer polls/reads it;
        # closing it earlier makes poll() report the fd and recvfrom() fail in that thread.
        self.sock.close()
        logging.debug(self.deb + "Closed")