Commit b8793705 authored by StevenWdV's avatar StevenWdV

Moved header packing functions into _Header, decreased first SYN nr range,...

Moved header packing functions into _Header, decreased first SYN nr range, added too low SYN check, added polling time param, added data_length check
parent f7bb06c0
......@@ -15,7 +15,7 @@ parser.add_argument("-t", "--timeout", help="Define bTCP timeout in milliseconds
parser.add_argument("-i", "--input", help="File to send", default="tmp.file")
args = parser.parse_args()
binding = btcp.Binding(socket.AF_INET, None, args.window, args.timeout / 1000.)
binding = btcp.Binding(socket.AF_INET, None, args.window, args.timeout)
connection = binding.connect_client(0, 0, ("", 9001))
file = open(args.input, "r+b")
......
......@@ -16,7 +16,7 @@ parser.add_argument("-t", "--timeout", help="Define bTCP timeout in milliseconds
parser.add_argument("-o", "--output", help="Where to store file", default="tmp.file")
args = parser.parse_args()
binding = btcp.Binding(socket.AF_INET, ("", 9001), args.window, args.timeout / 1000.)
binding = btcp.Binding(socket.AF_INET, ("", 9001), args.window, args.timeout)
server = binding.bind_server(0)
server.start_listen(1)
......
......@@ -42,6 +42,10 @@ class _Flags:
class _Header:
@staticmethod
def unpack(header: bytes) -> "_Header":
return _Header(*struct.unpack(header_format, header))
def __init__(self, src_port: int, dest_port: int, syn_nr: int, ack_nr: int, flags: int,
window: int, data_length: int, checksum: int = 0):
self.dest_port = dest_port
......@@ -56,22 +60,17 @@ class _Header:
def flags_obj(self) -> _Flags:
return _Flags(self.flags)
def _pack_header(header: _Header) -> bytes:
return struct.pack(
header_format,
header.src_port,
header.dest_port,
header.syn_nr,
header.ack_nr,
header.flags,
header.window,
header.data_length,
header.checksum)
def _unpack_header(header: bytes) -> _Header:
return _Header(*struct.unpack(header_format, header))
def __bytes__(self) -> bytes:
return struct.pack(
header_format,
self.src_port,
self.dest_port,
self.syn_nr,
self.ack_nr,
self.flags,
self.window,
self.data_length,
self.checksum)
class _Packet:
......@@ -85,7 +84,7 @@ class _Packet:
else:
header_no_checksum = copy.deepcopy(self.header)
header_no_checksum.checksum = 0
return binascii.crc32(_pack_header(header_no_checksum) + self.data)
return binascii.crc32(bytes(header_no_checksum) + self.data)
def verify_checksum(self) -> bool:
return self.compute_checksum() == self.header.checksum
......@@ -116,6 +115,8 @@ class _RemoteInit:
self.window_size = window_size
# TODO more timeouts
# TODO fast retransmit
class _Stream:
def __init__(self, binding: "Binding", local_port: int,
remote_port: int, remote_udp_addr: Any,
......@@ -141,7 +142,7 @@ class _Stream:
self.deb = f"{self.binding.local_udp_addr}|{self.local_port}->{self.remote_port}: "
self.first_send_syn_nr = random.randint(0, 0xffFF)
self.first_send_syn_nr = random.randint(0, 0xff)
if remote_init is None: # We initiate the connection
self.send_syn_nr = self.first_send_syn_nr # Next to be sent
......@@ -221,6 +222,10 @@ class _Stream:
# TODO simultaneous open if self.first_recv_syn_nr is None
continue
if packet.header.syn_nr < self.first_recv_syn_nr:
logging.warning(self.deb + "Too low SYN")
continue
if packet.header.syn_nr > self.recv_syn_nr + self.local_window_size:
logging.warning(self.deb + "Too high SYN")
continue
......@@ -437,7 +442,7 @@ class _Server:
self.streams[packet.header.src_port] = stream = _Stream(
self.binding, self.local_port,
packet.header.src_port, packet.remote_udp_addr,
self.binding.window_size, self.binding.timeout,
self.binding.window_size, self.binding.timeout_sec,
_RemoteInit(packet.header.syn_nr, packet.header.window))
self.backlog_connections.append(_Connection(self, stream))
self.backlog -= 1
......@@ -489,10 +494,11 @@ class _Server:
class Binding:
def __init__(self, protocol: int, local_udp_addr: Optional[Any],
window_size: int, timeout_sec: float):
window_size: int, timeout_ms: int, poll_time_ms: Optional[int] = None):
self.local_udp_addr = local_udp_addr
self.window_size = window_size
self.timeout = timeout_sec
self.timeout_sec = timeout_ms / 1000
self.poll_time_ms = poll_time_ms or timeout_ms
self.sock = socket.socket(protocol, socket.SOCK_DGRAM)
if local_udp_addr is not None:
......@@ -513,7 +519,7 @@ class Binding:
poller.register(self.sock, select.POLLIN)
while True:
poll_result = poller.poll(100)
poll_result = poller.poll(self.poll_time_ms)
packet_batches: Dict[int, List[_AddrPacket]] = {}
......@@ -524,8 +530,10 @@ class Binding:
while len(poll_result) > 0:
data, addr = self.sock.recvfrom(header_size + payload_size, socket.MSG_WAITALL)
logging.debug(self.deb + f"Packet from {addr}")
packet = _AddrPacket(_unpack_header(data[:header_size]),
packet = _AddrPacket(_Header.unpack(data[:header_size]),
data[header_size:], addr)
if packet.header.data_length > payload_size:
logging.warning(self.deb + "data_length too large")
packet.data = packet.data[:-(payload_size - packet.header.data_length)]
if not packet.verify_checksum():
......@@ -552,7 +560,7 @@ class Binding:
def _send(self, packet: _Packet, remote_udp_addr: Any) -> None:
logging.debug(self.deb + f"Send to {remote_udp_addr}")
data = _pack_header(packet.header) + packet.data + bytes(payload_size - len(packet.data))
data = bytes(packet.header) + packet.data + bytes(payload_size - len(packet.data))
while len(data) > 0:
data = data[self.sock.sendto(data, remote_udp_addr):]
......@@ -573,7 +581,7 @@ class Binding:
logging.debug(self.deb + f"Connect to btcp {local_btcp_port} -> {remote_btcp_port} {remote_udp_addr}")
self.sockets_stop_lock.acquire()
connection = _Connection(self, _Stream(self, local_btcp_port, remote_btcp_port, remote_udp_addr,
self.window_size, self.timeout))
self.window_size, self.timeout_sec))
self.sockets[local_btcp_port] = connection
self.sockets_stop_lock.release()
return connection
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment