btcp.py 22.8 KB
Newer Older
StevenWdV's avatar
StevenWdV committed
1
import binascii
2
import bisect
StevenWdV's avatar
StevenWdV committed
3
import copy
StevenWdV's avatar
StevenWdV committed
4
import logging
StevenWdV's avatar
StevenWdV committed
5
import random
6
7
8
9
10
11
import select
import socket
import struct
import threading
import time
from typing import *
StevenWdV's avatar
StevenWdV committed
12

13
# Wire layout of a bTCP header, network byte order:
#   src port, dest port, SYN nr, ACK nr  (4 x 16-bit unsigned)
#   flags, window                        (2 x 8-bit unsigned)
#   data length                          (16-bit unsigned)
#   checksum                             (32-bit unsigned CRC32)
header_format = "!HHHHBBHI"
# Derived from the format so the two can never disagree (16 bytes).
header_size = struct.calcsize(header_format)
# Size of the payload area of every datagram; shorter payloads are zero-padded
# on send and trimmed to the header's data length on receive.
payload_size = 1000


class _Flags:
    def __init__(self, flags: Union[int, Tuple[bool, bool, bool]] = 0):
        """
        :param flags: Combined or (SYN, ACK, FIN)
        """
        if type(flags) is int:
            self.fin = bool(flags & 1)
            self.ack = bool(flags & 1 << 1)
            self.syn = bool(flags & 1 << 2)
        else:
            self.syn, self.ack, self.fin = flags

StevenWdV's avatar
StevenWdV committed
30
    def __int__(self) -> int:
31
        return self.syn << 2 | self.ack << 1 | self.fin
StevenWdV's avatar
StevenWdV committed
32

StevenWdV's avatar
StevenWdV committed
33
34
35
36
37
38
39
40
41
42
    def __str__(self) -> str:
        strs: List[str] = []
        if self.syn:
            strs.append("SYN")
        if self.ack:
            strs.append("ACK")
        if self.fin:
            strs.append("FIN")
        return "-".join(strs)

StevenWdV's avatar
StevenWdV committed
43

44
45
46
47
48
49
50
class _Header:
    """In-memory form of a bTCP header.

    The constructor arguments follow the on-the-wire field order, so the tuple
    produced by ``struct.unpack(header_format, ...)`` can be splatted into it.
    """

    def __init__(self, src_port: int, dest_port: int, syn_nr: int, ack_nr: int, flags: int,
                 window: int, data_length: int, checksum: int = 0):
        (self.src_port, self.dest_port,
         self.syn_nr, self.ack_nr,
         self.flags, self.window,
         self.data_length, self.checksum) = (src_port, dest_port,
                                             syn_nr, ack_nr,
                                             flags, window,
                                             data_length, checksum)

    def flags_obj(self) -> _Flags:
        """Decode the packed ``flags`` field into a ``_Flags`` object."""
        return _Flags(self.flags)
StevenWdV's avatar
StevenWdV committed
58
59


60
61
62
63
64
65
66
67
68
69
70
def _pack_header(header: _Header) -> bytes:
    """Serialise *header* into its fixed-size wire form using ``header_format``."""
    fields = (
        header.src_port, header.dest_port,
        header.syn_nr, header.ack_nr,
        header.flags, header.window,
        header.data_length, header.checksum,
    )
    return struct.pack(header_format, *fields)
StevenWdV's avatar
StevenWdV committed
71
72


73
74
def _unpack_header(header: bytes) -> _Header:
    """Parse a buffer of exactly ``header_size`` bytes into a ``_Header``."""
    fields = struct.unpack(header_format, header)
    return _Header(*fields)
StevenWdV's avatar
StevenWdV committed
75

StevenWdV's avatar
StevenWdV committed
76

77
78
class _Packet:
    def __init__(self, header: _Header, data: bytes):
StevenWdV's avatar
StevenWdV committed
79
80
        self.header = header
        self.data = data
StevenWdV's avatar
StevenWdV committed
81

StevenWdV's avatar
StevenWdV committed
82
    def compute_checksum(self) -> int:
83
84
85
86
87
88
        if self.header.checksum == 0:
            header_no_checksum = self.header
        else:
            header_no_checksum = copy.deepcopy(self.header)
            header_no_checksum.checksum = 0
        return binascii.crc32(_pack_header(header_no_checksum) + self.data)
StevenWdV's avatar
StevenWdV committed
89

StevenWdV's avatar
StevenWdV committed
90
91
92
    def verify_checksum(self) -> bool:
        return self.compute_checksum() == self.header.checksum

93
    def set_checksum(self) -> None:
StevenWdV's avatar
StevenWdV committed
94
        self.header.checksum = self.compute_checksum()
StevenWdV's avatar
StevenWdV committed
95

96
97
    def __lt__(self, other) -> bool:
        return self.header.syn_nr < other.header.synNr
StevenWdV's avatar
StevenWdV committed
98

StevenWdV's avatar
StevenWdV committed
99

100
101
102
103
class _TimestampPacket(_Packet):
    """A sent packet remembered with the ``time.perf_counter()`` value from when it was sent."""

    def __init__(self, header: _Header, data: bytes, timestamp: float):
        self.timestamp = timestamp
        super().__init__(header, data)
StevenWdV's avatar
StevenWdV committed
104
105


106
107
108
109
class _AddrPacket(_Packet):
    """A received packet tagged with the UDP address it came from."""

    def __init__(self, header: _Header, data: bytes, remote_udp_addr: Tuple[str, int]):
        self.remote_udp_addr = remote_udp_addr
        super().__init__(header, data)
StevenWdV's avatar
StevenWdV committed
110
111


StevenWdV's avatar
StevenWdV committed
112
# Handshake info for a connection that the remote side initiated (rather than us)
113
114
115
116
class _RemoteInit:
    def __init__(self, syn_nr: int, window_size: int):
        self.syn_nr = syn_nr
        self.window_size = window_size
StevenWdV's avatar
StevenWdV committed
117
118


119
120
121
class _Stream:
    """One endpoint of a bTCP connection: handshake, sequencing, ACKs and retransmission.

    Sequence numbers count packets, not bytes: the SYN / SYN-ACK and every data
    packet each consume one number.

    Buffers:
      * ``send_buffer`` - payload chunks queued by ``enqueue_data``, not yet sent;
      * ``sent_ack_buffer`` - sent packets the peer has not ACKed yet;
      * ``receive_ack_buffer`` - out-of-order packets, kept sorted by sequence nr;
      * ``receive_buffer`` - in-order payloads waiting for ``dequeue_*``.

    ``pass_received_msgs`` and ``poll_sender`` are driven by the binding's
    background thread; ``enqueue_data`` / ``dequeue_*`` are the user-facing side
    (via ``_Connection``) and share state through ``send_buffer_lock`` and
    ``receive_buffer_lock``.

    NOTE(review): sequence/ACK numbers are sent in 16-bit header fields but are
    never reduced modulo 2**16, so struct.pack fails once one exceeds 0xFFFF
    (the initial number is random in [0, 0xFFFF]). Proper wrap-around needs
    modular comparisons throughout - TODO.
    """

    def __init__(self, binding: "Binding", local_port: int,
                 remote_port: int, remote_udp_addr: Any,
                 local_window_size: int, timeout: float,
                 remote_init: Optional[_RemoteInit] = None):
        """
        :param remote_init: None when we open the connection (a SYN is sent);
            otherwise the peer's SYN values, and a SYN-ACK is sent.
        :param timeout: Seconds after which an unACKed packet is retransmitted.
        """
        self.binding = binding
        self.local_port = local_port
        self.remote_port = remote_port
        self.local_window_size = local_window_size
        self.remote_udp_addr = remote_udp_addr
        self.timeout = timeout
        self.expect_syn_ack = remote_init is None

        self.send_buffer: List[bytes] = []  # Data to be sent
        self.sent_ack_buffer: List[_TimestampPacket] = []  # Data still to be ACKed by other

        self.receive_ack_buffer: List[_Packet] = []  # Received but not ACKed by me
        self.receive_buffer: List[bytes] = []  # Received and ACKed, but not yet retrieved

        self.send_buffer_lock = threading.Lock()
        self.receive_buffer_lock = threading.Lock()
        self.receive_bytes_available = threading.Condition(self.receive_buffer_lock)

        self.deb = f"{self.binding.local_udp_addr}|{self.local_port}->{self.remote_port}: "

        self.first_send_syn_nr = random.randint(0, 0xffFF)

        if remote_init is None:  # We initiate the connection
            self.send_syn_nr = self.first_send_syn_nr  # Next to be sent
            self.send_ack_nr = self.send_syn_nr  # First not ACKed by other

            self.first_recv_syn_nr = None

            self.__send_to_be_acked(b"", self.send_syn_nr, _Flags((True, False, False)))

        else:
            self.remote_window_size = remote_init.window_size

            self.recv_syn_nr = remote_init.syn_nr  # Last received
            self.recv_ack_nr = remote_init.syn_nr + 1  # First still to be received
            self.send_syn_nr = self.first_send_syn_nr  # Next to be sent
            self.send_ack_nr = self.send_syn_nr  # First not ACKed by other

            self.first_recv_syn_nr = self.recv_syn_nr

            self.__send_to_be_acked(b"", self.send_syn_nr, _Flags((True, True, False)))

        # The SYN / SYN-ACK consumed one sequence number.
        self.send_syn_nr += 1
        logging.debug(self.deb + "Setup stream")

    def __send_to_be_acked(self, data: bytes, syn_nr: int, flags: Optional[_Flags] = None) -> None:
        """Send a packet that must be ACKed and keep it in ``sent_ack_buffer`` for retransmission.

        :param syn_nr: Sequence number of the packet (retransmissions reuse the original one).
        :param flags: Defaults to a plain ACK; the ACK nr is only filled in when the ACK flag is set.
        """
        if flags is None:
            flags = _Flags((False, True, False))
        logging.debug(self.deb +
                      f"Sending #{syn_nr} {flags}{f' with ACK {self.recv_ack_nr}' if flags.ack else ''} "
                      f"with {len(data)} bytes")
        self.send_syn_nr = max(syn_nr, self.send_syn_nr)
        header = _Header(self.local_port, self.remote_port, syn_nr, self.recv_ack_nr if flags.ack else 0,
                         int(flags), self.local_window_size, len(data))
        packet = _TimestampPacket(header, data, time.perf_counter())
        packet.set_checksum()
        self.sent_ack_buffer.append(packet)
        # noinspection PyProtectedMember
        self.binding._send(packet, self.remote_udp_addr)

    def __send_ack(self) -> None:
        """Send a stand-alone ACK (not retransmitted, so not kept in ``sent_ack_buffer``)."""
        logging.debug(self.deb + f"Sending ACK {self.recv_ack_nr}")
        ack_header = _Header(self.local_port, self.remote_port, self.send_syn_nr, self.recv_ack_nr,
                             int(_Flags((False, True, False))), self.local_window_size, 0)
        ack_packet = _Packet(ack_header, b"")
        ack_packet.set_checksum()
        # noinspection PyProtectedMember
        self.binding._send(ack_packet, self.remote_udp_addr)

    def pass_received_msgs(self, packets: List[_AddrPacket]) -> None:
        """Process a batch of packets from the peer and ACK them if needed."""
        logging.debug(self.deb + f"Received {len(packets)} messages")

        send_ack = False

        for packet in packets:
            flags = packet.header.flags_obj()
            logging.debug(self.deb + f"Received #{packet.header.syn_nr} {flags}"
                          f"{f' with ACK {packet.header.ack_nr}' if flags.ack else ''} "
                          f"with {len(packet.data)} bytes")

            if self.expect_syn_ack:
                if flags.syn and flags.ack and not flags.fin:
                    if packet.header.ack_nr != self.send_ack_nr + 1:
                        logging.warning(self.deb + "Wrong ACK nr in SYN-ACK")
                        continue
                    self.expect_syn_ack = False
                    self.recv_syn_nr = packet.header.syn_nr
                    self.recv_ack_nr = packet.header.syn_nr + 1
                    self.first_recv_syn_nr = self.recv_syn_nr
                    self.remote_window_size = packet.header.window
                    # Final handshake step: ACK the SYN-ACK, otherwise the peer keeps
                    # retransmitting it until we happen to send data.
                    send_ack = True
                else:
                    logging.warning(self.deb + "SYN-ACK expected")
                    continue
            elif flags.syn:
                if self.first_recv_syn_nr is not None and packet.header.syn_nr == self.first_recv_syn_nr:
                    logging.debug(self.deb + "Spurious SYN")
                    if flags.ack:
                        # A repeated SYN-ACK means our handshake ACK got lost: resend it.
                        send_ack = True
                    continue
                else:
                    logging.warning(self.deb + "Unexpected SYN")
                    # TODO simultaneous open if self.first_recv_syn_nr is None
                    continue

            if packet.header.syn_nr > self.recv_syn_nr + self.local_window_size:
                logging.warning(self.deb + "Too high SYN")
                continue

            # TODO if flags.fin:

            if flags.ack:
                # send_syn_nr is the next number we will send, so it is the highest
                # ACK the peer can legitimately send.
                if packet.header.ack_nr > self.send_syn_nr:
                    logging.warning(self.deb + "Too high ACK")
                    continue
                if packet.header.ack_nr < self.first_send_syn_nr:
                    logging.warning(self.deb + "Too low ACK")
                    continue
                if packet.header.ack_nr > self.send_ack_nr:
                    self.send_ack_nr = packet.header.ack_nr
                    # Delete ACKed messages
                    self.sent_ack_buffer = [p for p in self.sent_ack_buffer if p.header.syn_nr >= packet.header.ack_nr]

            self.recv_syn_nr = max(self.recv_syn_nr, packet.header.syn_nr)

            if packet.header.data_length > 0:
                self.__accept_data(packet)
                # Also ACK old duplicates so the sender stops retransmitting them.
                send_ack = True

        if send_ack:
            # poll_sender runs right after this (from the binding's loop); skip the
            # stand-alone ACK only if it will actually piggyback on a data packet.
            with self.send_buffer_lock:
                data_to_send = len(self.send_buffer) > 0
            if not (data_to_send and len(self.sent_ack_buffer) < self.remote_window_size):
                self.__send_ack()

    def __accept_data(self, packet: _Packet) -> None:
        """Deliver an in-order data packet (plus the buffered run it completes) or buffer an early one."""
        syn_nr = packet.header.syn_nr
        if syn_nr == self.recv_ack_nr:
            with self.receive_buffer_lock:
                self.receive_buffer.append(packet.data)
                self.recv_ack_nr += 1
                # receive_ack_buffer is sorted, so the run this packet completes is at its front.
                while self.receive_ack_buffer and self.receive_ack_buffer[0].header.syn_nr == self.recv_ack_nr:
                    self.receive_buffer.append(self.receive_ack_buffer.pop(0).data)
                    self.recv_ack_nr += 1
                # Wake every reader in case several wait for different byte counts.
                self.receive_bytes_available.notify_all()
        elif syn_nr > self.recv_ack_nr:
            # There is a gap: keep the packet, sorted by sequence nr, until it can be delivered.
            buffered_nrs = [p.header.syn_nr for p in self.receive_ack_buffer]
            insert_index = bisect.bisect_left(buffered_nrs, syn_nr)
            if insert_index < len(buffered_nrs) and buffered_nrs[insert_index] == syn_nr:
                if self.receive_ack_buffer[insert_index].data != packet.data:
                    logging.warning(self.deb + "Retransmission with different data")
                logging.debug(self.deb + "Spurious retransmission")
            else:
                self.receive_ack_buffer.insert(insert_index, packet)
        # else: already delivered; nothing to store.

    def poll_sender(self) -> None:
        """Retransmit timed-out packets, then send queued data while the peer's window allows."""
        now = time.perf_counter()
        # Use one timestamp for both tests so no packet falls between them and
        # silently disappears from the retransmission queue.
        lost = [p for p in self.sent_ack_buffer if now - p.timestamp >= self.timeout]
        self.sent_ack_buffer = [p for p in self.sent_ack_buffer if now - p.timestamp < self.timeout]
        if len(lost) > 0:
            logging.debug(self.deb + f"{len(lost)} lost messages")
        for p in lost:
            # Re-sending builds a fresh timestamped packet (with the current ACK nr)
            # and puts it back into sent_ack_buffer.
            self.__send_to_be_acked(p.data, p.header.syn_nr, p.header.flags_obj())

        if self.expect_syn_ack:
            return

        with self.send_buffer_lock:
            if len(self.send_buffer) > 0:
                logging.debug(self.deb + f"{len(self.send_buffer)} unsent messages")
            while len(self.send_buffer) > 0 and len(self.sent_ack_buffer) < self.remote_window_size:
                data = self.send_buffer.pop(0)
                self.__send_to_be_acked(data, self.send_syn_nr)
                self.send_syn_nr += 1

    def enqueue_data(self, data: bytes) -> None:
        """Queue *data* for sending, split into chunks of at most ``payload_size`` bytes."""
        logging.debug(self.deb + f"Enqueuing {len(data)} bytes")
        chunks = [data[offset:offset + payload_size] for offset in range(0, len(data), payload_size)]
        with self.send_buffer_lock:
            self.send_buffer.extend(chunks)

    def __at_least_received(self, count: int) -> bool:
        """True if ``receive_buffer`` holds at least *count* bytes; caller holds ``receive_buffer_lock``."""
        remaining = count
        for chunk in self.receive_buffer:
            if remaining <= 0:
                break
            remaining -= len(chunk)
        return remaining <= 0

    def dequeue_data(self, count: int) -> bytes:
        """Block until *count* bytes are available, then remove and return exactly that many."""
        logging.debug(self.deb + f"Dequeuing {count} bytes...")
        parts: List[bytes] = []
        with self.receive_bytes_available:
            self.receive_bytes_available.wait_for(lambda: self.__at_least_received(count))
            remaining = count
            while remaining > 0:
                chunk = self.receive_buffer[0]
                if len(chunk) <= remaining:
                    parts.append(self.receive_buffer.pop(0))
                    remaining -= len(chunk)
                else:
                    # Split the chunk; the rest stays at the front of the buffer.
                    parts.append(chunk[:remaining])
                    self.receive_buffer[0] = chunk[remaining:]
                    remaining = 0
        logging.debug(self.deb + "Dequeued")
        return b"".join(parts)

    def dequeue_is_available(self, count: int) -> bool:
        """Non-blocking check whether ``dequeue_data(count)`` would return immediately."""
        with self.receive_buffer_lock:
            return self.__at_least_received(count)

    def dequeue_all_data(self) -> bytes:
        """Remove and return everything currently received, without blocking."""
        logging.debug(self.deb + "Dequeuing all data...")
        with self.receive_buffer_lock:
            data = b"".join(self.receive_buffer)
            self.receive_buffer.clear()
        logging.debug(self.deb + f"Dequeued {len(data)} bytes")
        return data

    def close(self) -> None:
        """Close the stream (connection teardown is not implemented yet)."""
        logging.debug(self.deb + "Close")
        # TODO FIN
        pass


class _Connection:
    """User-facing handle for one ``_Stream``.

    ``parent`` is the ``_Server`` for accepted connections, or the ``Binding``
    for client connections. It determines how ``close`` deregisters the stream.
    """

    def __init__(self, parent: Union["_Server", "Binding"], stream: _Stream):
        self.parent = parent
        self.stream = stream
        self.deb = f"{self.__parent_str()}|{self.stream.local_port}->{self.stream.remote_port}: "
        logging.debug(self.deb + "Setup connection")

    def __parent_str(self) -> str:
        """Local UDP address of the binding this connection ultimately belongs to."""
        if isinstance(self.parent, _Server):
            return self.parent.binding.local_udp_addr
        return self.parent.local_udp_addr

    def _pass_received_msgs(self, packets: List[_AddrPacket]) -> None:
        """Forward packets from the binding to the stream."""
        self.stream.pass_received_msgs(packets)

    def _poll_streams(self) -> None:
        """Let the stream retransmit and send queued data."""
        self.stream.poll_sender()

    def send(self, data: bytes) -> None:
        """Queue *data* for sending; returns without waiting for delivery."""
        self.stream.enqueue_data(data)

    def receive(self, count: int) -> bytes:
        """Block until exactly *count* bytes have been received and return them."""
        return self.stream.dequeue_data(count)

    def receive_is_available(self, count: int) -> bool:
        """True if ``receive(count)`` would not block."""
        return self.stream.dequeue_is_available(count)

    def receive_all(self) -> bytes:
        """Return all bytes received so far without blocking."""
        return self.stream.dequeue_all_data()

    def close(self) -> None:
        """Deregister the connection from its parent and close the stream."""
        logging.debug(self.deb + "Close")
        if isinstance(self.parent, _Server):
            # Accepted connection: the server keys its streams by remote port.
            # noinspection PyProtectedMember
            self.parent._remove_stream(self.stream.remote_port)
        else:
            # Client connection: the binding keys it by our local bTCP port
            # (see Binding.connect_client), not by the remote one.
            # noinspection PyProtectedMember
            self.parent._remove_socket(self.stream.local_port)
        self.stream.close()
        logging.debug(self.deb + "Closed")
397
398
399
400
401
402


class _Server:
    """A listening bTCP port.

    Routes incoming packets to per-peer streams, keyed by the peer's bTCP port.
    A connection-opening SYN from an unknown port becomes a new connection,
    queued for ``accept``.
    """

    def __init__(self, binding: "Binding", local_port: int):
        self.binding = binding
        self.local_port = local_port
        self.streams: Dict[int, _Stream] = {}  # Remote port -> _Stream
        self.streams_lock = threading.Lock()

        # Free slots for connections that have not been returned by accept() yet.
        self.backlog = 0
        self.backlog_connections: List[_Connection] = []
        self.backlog_lock = threading.Lock()
        self.backlog_available = threading.Condition(self.backlog_lock)

        self.deb = f"{self.binding.local_udp_addr}|{self.local_port}: "
        logging.debug(self.deb + "Setup server")

    def _pass_received_msgs(self, packets: List[_AddrPacket]) -> None:
        """Dispatch a batch of packets addressed to this port to the right streams."""
        logging.debug(self.deb + f"Received {len(packets)} messages")

        packet_batches: Dict[int, List[_AddrPacket]] = {}
        orphan_packets: List[_AddrPacket] = []

        with self.streams_lock:
            for packet in packets:
                if packet.header.src_port in self.streams:
                    packet_batches.setdefault(packet.header.src_port, []).append(packet)
                else:
                    orphan_packets.append(packet)

            for remote_port, batch in packet_batches.items():
                self.streams[remote_port].pass_received_msgs(batch)

            for packet in orphan_packets:
                self.__handle_orphan(packet)

    def __handle_orphan(self, packet: _AddrPacket) -> None:
        """Handle a packet whose peer port had no stream when the batch arrived; caller holds ``streams_lock``."""
        src_port = packet.header.src_port
        if src_port in self.streams:
            # An earlier SYN in this same batch already opened the stream (e.g. a
            # retransmitted SYN). Let that stream handle it rather than replacing
            # the stream and using up a second backlog slot.
            self.streams[src_port].pass_received_msgs([packet])
            return

        flags = packet.header.flags_obj()
        if not (flags.syn and not flags.ack and not flags.fin
                and packet.header.data_length == 0):
            logging.warning(self.deb + "Unknown connection")
            return

        with self.backlog_lock:
            has_room = self.backlog > 0
            if has_room:
                self.streams[src_port] = stream = _Stream(
                    self.binding, self.local_port,
                    src_port, packet.remote_udp_addr,
                    self.binding.window_size, self.binding.timeout,
                    _RemoteInit(packet.header.syn_nr, packet.header.window))
                self.backlog_connections.append(_Connection(self, stream))
                self.backlog -= 1
                self.backlog_available.notify()
        if not has_room:
            logging.warning(self.deb + "Backlog full")

    def _poll_streams(self) -> None:
        """Let every stream retransmit and send queued data."""
        with self.streams_lock:
            for stream in self.streams.values():
                stream.poll_sender()

    def _remove_stream(self, remote_port: int) -> None:
        """Forget the stream for *remote_port*; a no-op if it is already gone."""
        logging.debug(self.deb + "Remove stream")
        with self.streams_lock:
            self.streams.pop(remote_port, None)

    def start_listen(self, backlog: int) -> None:
        """Start accepting SYNs, queueing at most *backlog* connections not yet accepted."""
        logging.debug(self.deb + "Start listen")
        with self.backlog_lock:
            self.backlog = backlog

    def accept(self) -> _Connection:
        """Block until a new connection is queued and return it."""
        logging.debug(self.deb + "Accepting...")
        with self.backlog_available:
            self.backlog_available.wait_for(lambda: len(self.backlog_connections) > 0)
            connection = self.backlog_connections.pop(0)
            # The connection is no longer pending, so its backlog slot is free again.
            self.backlog += 1
        logging.debug(self.deb + "Accepted")
        return connection

    def close(self) -> None:
        """Stop listening and close every stream of this server."""
        logging.debug(self.deb + "Closing...")
        # noinspection PyProtectedMember
        self.binding._remove_socket(self.local_port)
        # streams is keyed by remote port: detach them all under the lock,
        # then close them outside it.
        with self.streams_lock:
            streams = list(self.streams.values())
            self.streams.clear()
        for stream in streams:
            stream.close()
        logging.debug(self.deb + "Closed")
488
489
490
491


class Binding:
    """A UDP socket that multiplexes bTCP servers and client connections.

    ``sockets`` maps a local bTCP port to the ``_Server`` or client
    ``_Connection`` bound to it. A background thread receives datagrams,
    dispatches them by destination port, and then polls every socket's senders.
    It waits at most 100 ms for new datagrams between rounds.
    """

    def __init__(self, protocol: int, local_udp_addr: Optional[Any],
                 window_size: int, timeout_sec: float):
        """
        :param protocol: Address family of the UDP socket (passed to ``socket.socket``).
        :param local_udp_addr: Address to bind to, or None to leave the socket unbound.
        :param window_size: Receive window advertised by every stream. NOTE(review): it is
            sent in the 8-bit window header field, so presumably must be <= 255.
        :param timeout_sec: Retransmission timeout for unACKed packets.
        """
        self.local_udp_addr = local_udp_addr
        self.window_size = window_size
        self.timeout = timeout_sec
        # Must exist before the background thread starts, since it logs with it.
        self.deb = f"{self.local_udp_addr}: "

        self.sock = socket.socket(protocol, socket.SOCK_DGRAM)
        if local_udp_addr is not None:
            self.sock.bind(local_udp_addr)

        self.sockets: Dict[int, Union[_Server, _Connection]] = {}  # Local port -> _Server / _Connection
        self.stop = False
        self.sockets_stop_lock = threading.RLock()

        self.read_thread = threading.Thread(None, self.__background)
        self.read_thread.start()

        logging.debug(self.deb + "Set up binding")

    def __background(self) -> None:
        """Receive/dispatch/poll loop run on ``read_thread`` until ``stop`` is set."""
        poller = select.poll()  # select.poll is not available on Windows
        poller.register(self.sock, select.POLLIN)

        while True:
            poll_result = poller.poll(100)

            # The `with` releases the lock even if handling raises; otherwise a dead
            # thread would keep owning it and close() would block forever.
            with self.sockets_stop_lock:
                if self.stop:
                    return

                packet_batches: Dict[int, List[_AddrPacket]] = {}
                while len(poll_result) > 0:
                    data, addr = self.sock.recvfrom(header_size + payload_size, socket.MSG_WAITALL)
                    packet = self.__parse_packet(data, addr)
                    if packet is not None:
                        packet_batches.setdefault(packet.header.dest_port, []).append(packet)
                    poll_result = poller.poll(0)

                for local_port, packets in packet_batches.items():
                    # noinspection PyProtectedMember
                    self.sockets[local_port]._pass_received_msgs(packets)

                for server in self.sockets.values():
                    # noinspection PyProtectedMember
                    server._poll_streams()

    def __parse_packet(self, data: bytes, addr: Any) -> Optional[_AddrPacket]:
        """Decode one datagram; log and return None if it has to be dropped."""
        logging.debug(self.deb + f"Packet from {addr}")
        if len(data) < header_size:
            logging.warning(self.deb + "Truncated packet")
            return None
        packet = _AddrPacket(_unpack_header(data[:header_size]),
                             data[header_size:], addr)
        # Strip the zero padding. Slicing by length keeps full-size payloads intact;
        # a negative end index would be -0 for them and yield b"".
        packet.data = packet.data[:packet.header.data_length]

        if not packet.verify_checksum():
            logging.warning(self.deb + "Invalid checksum")
            return None
        if packet.header.dest_port not in self.sockets:
            logging.warning(self.deb + "Unknown server")
            return None
        return packet

    def _send(self, packet: _Packet, remote_udp_addr: Any) -> None:
        """Send *packet* as one datagram, zero-padding the payload to ``payload_size``."""
        logging.debug(self.deb + f"Send to {remote_udp_addr}")
        padding = bytes(payload_size - len(packet.data))
        # A UDP datagram is sent whole or not at all, so a single sendto suffices.
        self.sock.sendto(_pack_header(packet.header) + packet.data + padding, remote_udp_addr)

    def _remove_socket(self, local_port: int) -> None:
        """Unregister whatever is bound to *local_port*; a no-op if nothing is."""
        logging.debug(self.deb + f"Remove socket {local_port}")
        with self.sockets_stop_lock:
            self.sockets.pop(local_port, None)

    def bind_server(self, local_btcp_port: int) -> _Server:
        """Create and register a server on *local_btcp_port* (call ``start_listen`` on it to accept)."""
        logging.debug(self.deb + f"Bind server to btcp {local_btcp_port}")
        with self.sockets_stop_lock:
            server = self.sockets[local_btcp_port] = _Server(self, local_btcp_port)
        return server

    def connect_client(self, local_btcp_port: int, remote_btcp_port: int, remote_udp_addr: Any) -> _Connection:
        """Open a client connection (the SYN is sent immediately) and register it under the local port."""
        logging.debug(self.deb + f"Connect to btcp {local_btcp_port} -> {remote_btcp_port} {remote_udp_addr}")
        with self.sockets_stop_lock:
            connection = _Connection(self, _Stream(self, local_btcp_port, remote_btcp_port, remote_udp_addr,
                                                   self.window_size, self.timeout))
            self.sockets[local_btcp_port] = connection
        return connection

    def close(self) -> None:
        """Stop the background thread, close all servers/connections and the UDP socket."""
        logging.debug(self.deb + "Closing...")
        with self.sockets_stop_lock:
            self.stop = True

        self.read_thread.join()
        logging.debug(self.deb + "Thread exited")

        with self.sockets_stop_lock:
            # Snapshot first: closing an endpoint may remove it from self.sockets.
            for endpoint in list(self.sockets.values()):
                endpoint.close()
            self.sockets.clear()
        self.sock.close()
        logging.debug(self.deb + "Closed")