Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Steven Wallis de Vries
bTCP
Commits
1d32f477
Commit
1d32f477
authored
Apr 25, 2019
by
StevenWdV
Browse files
First working prototype
parent
ac7b2e9b
Changes
3
Show whitespace changes
Inline
Side-by-side
bTCP_client.py
View file @
1d32f477
#!/
usr/local/
bin/python3
#!/bin/python3
import
argparse
import
logging
import
socket
import
struct
import
sys
import
btcp
logging
.
basicConfig
(
stream
=
sys
.
stdout
,
level
=
logging
.
DEBUG
)
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
"-w"
,
"--window"
,
help
=
"Define bTCP window size"
,
type
=
int
,
default
=
100
)
parser
.
add_argument
(
"-t"
,
"--timeout"
,
help
=
"Define bTCP timeout in milliseconds"
,
type
=
int
,
default
=
100
)
parser
.
add_argument
(
"-i"
,
"--input"
,
help
=
"File to send"
,
default
=
"tmp.file"
)
args
=
parser
.
parse_args
()
binding
=
btcp
.
Binding
(
socket
.
AF_INET
,
(
""
,
9002
)
,
args
.
window
,
args
.
timeout
)
binding
=
btcp
.
Binding
(
socket
.
AF_INET
,
None
,
args
.
window
,
args
.
timeout
/
1000.
)
connection
=
binding
.
connect_client
(
0
,
0
,
(
""
,
9001
))
file
=
open
(
args
.
input
,
"rb"
)
file
=
open
(
args
.
input
,
"r
+
b"
)
data
=
file
.
read
()
file
.
close
()
connection
.
send
(
struct
.
pack
(
"!Q"
,
len
(
data
)))
connection
.
send
(
data
)
input
(
"press enter to stop
\n
"
)
binding
.
close
()
bTCP_server.py
View file @
1d32f477
#!/
usr/local/
bin/python3
#!/bin/python3
import
argparse
import
logging
import
socket
import
struct
import
sys
import
btcp
logging
.
basicConfig
(
stream
=
sys
.
stdout
,
level
=
logging
.
DEBUG
)
# Handle arguments
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
"-w"
,
"--window"
,
help
=
"Define bTCP window size"
,
type
=
int
,
default
=
100
)
...
...
@@ -12,7 +16,7 @@ parser.add_argument("-t", "--timeout", help="Define bTCP timeout in milliseconds
parser
.
add_argument
(
"-o"
,
"--output"
,
help
=
"Where to store file"
,
default
=
"tmp.file"
)
args
=
parser
.
parse_args
()
binding
=
btcp
.
Binding
(
socket
.
AF_INET
,
(
""
,
9001
),
args
.
window
,
args
.
timeout
)
binding
=
btcp
.
Binding
(
socket
.
AF_INET
,
(
""
,
9001
),
args
.
window
,
args
.
timeout
/
1000.
)
server
=
binding
.
bind_server
(
0
)
server
.
start_listen
(
1
)
...
...
@@ -20,11 +24,9 @@ connection = server.accept()
file_size
:
int
=
struct
.
unpack
(
"!Q"
,
connection
.
receive
(
8
))[
0
]
file
=
open
(
args
.
output
,
"wb"
)
while
file_size
>
0
:
data
=
connection
.
receive_all
()
file
.
write
(
data
)
file_size
-=
len
(
data
)
file
=
open
(
args
.
output
,
"w+b"
)
data
=
connection
.
receive
(
file_size
)
file
.
write
(
data
)
file
.
close
()
binding
.
close
()
btcp.py
View file @
1d32f477
import
binascii
import
bisect
import
copy
import
logging
import
random
import
select
import
socket
...
...
@@ -26,9 +27,19 @@ class _Flags:
else
:
self
.
syn
,
self
.
ack
,
self
.
fin
=
flags
def
to
_int
(
self
)
->
int
:
def
_
_int
__
(
self
)
->
int
:
return
self
.
syn
<<
2
|
self
.
ack
<<
1
|
self
.
fin
def
__str__
(
self
)
->
str
:
strs
:
List
[
str
]
=
[]
if
self
.
syn
:
strs
.
append
(
"SYN"
)
if
self
.
ack
:
strs
.
append
(
"ACK"
)
if
self
.
fin
:
strs
.
append
(
"FIN"
)
return
"-"
.
join
(
strs
)
class
_Header
:
def
__init__
(
self
,
src_port
:
int
,
dest_port
:
int
,
syn_nr
:
int
,
ack_nr
:
int
,
flags
:
int
,
...
...
@@ -98,6 +109,7 @@ class _AddrPacket(_Packet):
self
.
remote_udp_addr
=
remote_udp_addr
# Info for if the connection was not initiated from our side
class
_RemoteInit
:
def
__init__
(
self
,
syn_nr
:
int
,
window_size
:
int
):
self
.
syn_nr
=
syn_nr
...
...
@@ -107,13 +119,12 @@ class _RemoteInit:
class
_Stream
:
def
__init__
(
self
,
binding
:
"Binding"
,
local_port
:
int
,
remote_port
:
int
,
remote_udp_addr
:
Any
,
local_window_size
:
int
,
timeout
:
in
t
,
local_window_size
:
int
,
timeout
:
floa
t
,
remote_init
:
Optional
[
_RemoteInit
]
=
None
):
self
.
binding
=
binding
self
.
local_port
=
local_port
self
.
remote_port
=
remote_port
self
.
local_window_size
=
local_window_size
self
.
remote_window_size
=
remote_init
.
window_size
self
.
remote_udp_addr
=
remote_udp_addr
self
.
timeout
=
timeout
self
.
expect_syn_ack
=
remote_init
is
None
...
...
@@ -128,35 +139,48 @@ class _Stream:
self
.
receive_buffer_lock
=
threading
.
Lock
()
self
.
receive_bytes_available
=
threading
.
Condition
(
self
.
receive_buffer_lock
)
if
remote_init
is
None
:
self
.
send_syn_nr
=
random
.
randint
(
0
,
0xffFF
)
# Next to be sent
self
.
deb
=
f
"
{
self
.
binding
.
local_udp_addr
}
|
{
self
.
local_port
}
->
{
self
.
remote_port
}
: "
self
.
first_send_syn_nr
=
random
.
randint
(
0
,
0xffFF
)
if
remote_init
is
None
:
# We initiate the connection
self
.
send_syn_nr
=
self
.
first_send_syn_nr
# Next to be sent
self
.
send_ack_nr
=
self
.
send_syn_nr
# First not ACKed by other
self
.
__send_to_be_acked
(
b
""
,
self
.
send_syn_nr
,
_Flags
((
True
,
False
,
False
)))
else
:
self
.
remote_window_size
=
remote_init
.
window_size
self
.
recv_syn_nr
=
remote_init
.
syn_nr
# Last received
self
.
recv_ack_nr
=
remote_init
.
syn_nr
+
1
# First still to be received
self
.
send_syn_nr
=
random
.
randint
(
0
,
0xffFF
)
# Next to be sent
self
.
send_syn_nr
=
self
.
first_send_syn_nr
# Next to be sent
self
.
send_ack_nr
=
self
.
send_syn_nr
# First not ACKed by other
self
.
first_recv_syn_nr
=
self
.
recv_syn_nr
self
.
__send_to_be_acked
(
b
""
,
self
.
send_syn_nr
,
_Flags
((
True
,
True
,
False
)))
self
.
send_syn_nr
+=
1
logging
.
debug
(
self
.
deb
+
"Setup stream"
)
def
__send_to_be_acked
(
self
,
data
:
bytes
,
syn_nr
:
int
,
flags
=
_Flags
((
False
,
True
,
False
)))
->
None
:
logging
.
debug
(
self
.
deb
+
f
"Sending #
{
syn_nr
}
{
flags
}{
f
' with ACK
{
self
.
recv_ack_nr
}
' if flags.ack else ''
}
"
f
"with
{
len
(
data
)
}
bytes"
)
self
.
send_syn_nr
=
max
(
syn_nr
,
self
.
send_syn_nr
)
header
=
_Header
(
self
.
local_port
,
self
.
remote_port
,
syn_nr
,
self
.
recv_ack_nr
if
flags
.
ack
else
0
,
flags
.
to_int
(
),
self
.
local_window_size
,
len
(
data
))
packet
=
_TimestampPacket
(
header
,
data
,
time
.
clock
())
int
(
flags
),
self
.
local_window_size
,
len
(
data
))
packet
=
_TimestampPacket
(
header
,
data
,
time
.
perf_counter
())
packet
.
set_checksum
()
self
.
sent_ack_buffer
.
append
(
packet
)
# noinspection PyProtectedMember
self
.
binding
.
_send
(
packet
,
self
.
remote_udp_addr
)
def
__send_ack
(
self
)
->
None
:
logging
.
debug
(
self
.
deb
+
f
"Sending ACK
{
self
.
recv_ack_nr
}
"
)
ack_header
=
_Header
(
self
.
local_port
,
self
.
remote_port
,
self
.
send_syn_nr
,
self
.
recv_ack_nr
,
_Flags
((
False
,
True
,
False
))
.
to_int
(
),
self
.
local_window_size
,
0
)
int
(
_Flags
((
False
,
True
,
False
))),
self
.
local_window_size
,
0
)
ack_packet
=
_Packet
(
ack_header
,
b
""
)
ack_packet
.
set_checksum
()
...
...
@@ -164,41 +188,49 @@ class _Stream:
self
.
binding
.
_send
(
ack_packet
,
self
.
remote_udp_addr
)
def
pass_received_msgs
(
self
,
packets
:
List
[
_AddrPacket
])
->
None
:
logging
.
debug
(
self
.
deb
+
f
"Received
{
len
(
packets
)
}
messages"
)
send_ack
=
False
for
packet
in
packets
:
if
packet
.
remote_udp_addr
!=
self
.
remote_udp_addr
:
print
(
"Wrong remote UDP address"
)
continue
flags
=
packet
.
header
.
flags_obj
()
logging
.
debug
(
self
.
deb
+
f
"Received #
{
packet
.
header
.
syn_nr
}
{
flags
}
"
f
"
{
f
' with ACK
{
packet
.
header
.
ack_nr
}
' if flags.ack else ''
}
"
f
"with
{
len
(
packet
.
data
)
}
bytes"
)
if
self
.
expect_syn_ack
:
if
flags
.
syn
and
flags
.
ack
and
not
flags
.
fin
:
if
packet
.
header
.
ack_nr
!=
self
.
send_ack_nr
:
print
(
"Wrong ACK nr in SYN-ACK"
)
if
packet
.
header
.
ack_nr
!=
self
.
send_ack_nr
+
1
:
logging
.
warning
(
self
.
deb
+
"Wrong ACK nr in SYN-ACK"
)
continue
self
.
expect_syn_ack
=
False
self
.
recv_syn_nr
=
packet
.
header
.
syn_nr
self
.
recv_ack_nr
=
packet
.
header
.
syn_nr
+
1
self
.
first_recv_syn_nr
=
self
.
recv_syn_nr
self
.
remote_window_size
=
packet
.
header
.
window
else
:
print
(
"SYN-ACK expected"
)
logging
.
warning
(
self
.
deb
+
"SYN-ACK expected"
)
continue
elif
flags
.
syn
:
print
(
"Unexpected SYN"
)
# TODO simultaneous open?
if
packet
.
header
.
syn_nr
==
self
.
first_recv_syn_nr
:
logging
.
debug
(
self
.
deb
+
"Spurious SYN"
)
continue
else
:
logging
.
warning
(
self
.
deb
+
"Unexpected SYN"
)
# TODO simultaneous open
continue
if
packet
.
header
.
syn_nr
>
self
.
recv_syn_nr
+
self
.
local_window_size
:
print
(
"Too high SYN"
)
logging
.
warning
(
self
.
deb
+
"Too high SYN"
)
continue
# TODO if flags.fin:
if
flags
.
ack
:
if
packet
.
header
.
ack_nr
>
self
.
send_syn_nr
+
1
:
print
(
"Too high ACK"
)
logging
.
warning
(
self
.
deb
+
"Too high ACK"
)
continue
if
packet
.
header
.
ack_nr
<
self
.
first_send_syn_nr
:
logging
.
warning
(
self
.
deb
+
"Too low ACK"
)
continue
if
packet
.
header
.
ack_nr
>
self
.
send_ack_nr
:
self
.
send_ack_nr
=
packet
.
header
.
ack_nr
...
...
@@ -225,9 +257,12 @@ class _Stream:
elif
packet
.
header
.
syn_nr
>
self
.
recv_ack_nr
:
# There is a gap, store this packet
insert_index
=
bisect
.
bisect_left
(
self
.
receive_ack_buffer
,
packet
)
# Check for duplicate packet
if
(
insert_index
+
1
>
len
(
self
.
receive_ack_buffer
)
or
self
.
receive_ack_buffer
[
insert_index
+
1
].
header
.
syn_nr
!=
packet
.
header
.
syn_nr
):
if
(
insert_index
+
1
<
len
(
self
.
receive_ack_buffer
)
and
self
.
receive_ack_buffer
[
insert_index
+
1
].
header
.
syn_nr
==
packet
.
header
.
syn_nr
):
assert
self
.
receive_ack_buffer
[
insert_index
+
1
].
data
==
packet
.
data
logging
.
debug
(
self
.
deb
+
"Spurious retransmission"
)
else
:
self
.
receive_ack_buffer
.
insert
(
insert_index
,
packet
)
send_ack
=
True
...
...
@@ -237,27 +272,34 @@ class _Stream:
self
.
send_buffer_lock
.
acquire
()
data_to_send
=
len
(
self
.
send_buffer
)
>
0
self
.
send_buffer_lock
.
release
()
if
data_to_send
:
if
not
data_to_send
:
self
.
__send_ack
()
def
poll_sender
(
self
)
->
None
:
lost
=
[
p
for
p
in
self
.
sent_ack_buffer
if
time
.
perf_counter
()
-
p
.
timestamp
>=
self
.
timeout
]
self
.
sent_ack_buffer
=
[
p
for
p
in
self
.
sent_ack_buffer
if
time
.
perf_counter
()
-
p
.
timestamp
<
self
.
timeout
]
if
len
(
lost
)
>
0
:
logging
.
debug
(
self
.
deb
+
f
"
{
len
(
lost
)
}
lost messages"
)
for
p
in
lost
:
p
.
timestamp
=
time
.
perf_counter
()
self
.
__send_to_be_acked
(
p
.
data
,
p
.
header
.
syn_nr
,
p
.
header
.
flags_obj
())
if
self
.
expect_syn_ack
:
return
lost
=
[
p
for
p
in
self
.
sent_ack_buffer
if
time
.
clock
()
-
p
.
timestamp
>=
self
.
timeout
]
for
p
in
lost
:
p
.
timestamp
=
time
.
clock
()
self
.
__send_to_be_acked
(
p
.
data
,
p
.
header
.
syn_nr
,
p
.
header
.
flags_obj
())
if
len
(
self
.
send_buffer
)
>
0
:
logging
.
debug
(
self
.
deb
+
f
"
{
len
(
self
.
send_buffer
)
}
unsent messages"
)
if
len
(
self
.
sent_ack_buffer
)
<
self
.
remote_window_size
:
self
.
send_buffer_lock
.
acquire
()
while
len
(
self
.
sent_ack_buffer
)
<
self
.
remote_window_size
:
while
len
(
self
.
send_buffer
)
>
0
and
len
(
self
.
sent_ack_buffer
)
<
self
.
remote_window_size
:
data
=
self
.
send_buffer
.
pop
(
0
)
self
.
__send_to_be_acked
(
data
+
bytes
(
payload_size
-
len
(
data
))
,
self
.
send_syn_nr
)
self
.
__send_to_be_acked
(
data
,
self
.
send_syn_nr
)
self
.
send_syn_nr
+=
1
self
.
send_buffer_lock
.
release
()
def
enqueue_data
(
self
,
data
:
bytes
)
->
None
:
logging
.
debug
(
self
.
deb
+
f
"Enqueuing
{
len
(
data
)
}
bytes"
)
self
.
send_buffer_lock
.
acquire
()
offset
=
0
while
offset
<
len
(
data
):
...
...
@@ -268,11 +310,12 @@ class _Stream:
def
__at_least_received
(
self
,
count
:
int
)
->
bool
:
i
=
0
while
i
<
len
(
self
.
receive_buffer
)
and
count
>
0
:
count
+
=
len
(
self
.
receive_buffer
[
i
])
count
-
=
len
(
self
.
receive_buffer
[
i
])
i
+=
1
return
count
<=
0
def
dequeue_data
(
self
,
count
:
int
)
->
bytes
:
logging
.
debug
(
self
.
deb
+
f
"Dequeuing
{
count
}
bytes..."
)
self
.
receive_buffer_lock
.
acquire
()
self
.
receive_bytes_available
.
wait_for
(
lambda
:
self
.
__at_least_received
(
count
))
...
...
@@ -281,10 +324,12 @@ class _Stream:
data
+=
self
.
receive_buffer
.
pop
(
0
)
if
len
(
data
)
<
count
:
data
+=
self
.
receive_buffer
[
0
][:
len
(
data
)
-
count
]
self
.
receive_buffer
[
0
]
=
self
.
receive_buffer
[
0
][
len
(
data
)
-
count
:]
bytes_short
=
count
-
len
(
data
)
data
+=
self
.
receive_buffer
[
0
][:
bytes_short
]
self
.
receive_buffer
[
0
]
=
self
.
receive_buffer
[
0
][
bytes_short
:]
self
.
receive_buffer_lock
.
release
()
logging
.
debug
(
self
.
deb
+
"Dequeued"
)
return
data
def
dequeue_is_available
(
self
,
count
:
int
)
->
bool
:
...
...
@@ -294,14 +339,17 @@ class _Stream:
return
available
def
dequeue_all_data
(
self
)
->
bytes
:
logging
.
debug
(
self
.
deb
+
"Dequeuing all data..."
)
self
.
receive_buffer_lock
.
acquire
()
data
=
b
""
while
len
(
self
.
receive_buffer
)
>
0
:
data
+=
self
.
receive_buffer
.
pop
(
0
)
self
.
receive_buffer_lock
.
release
()
logging
.
debug
(
self
.
deb
+
f
"Dequeued
{
len
(
data
)
}
bytes"
)
return
data
def
close
(
self
)
->
None
:
logging
.
debug
(
self
.
deb
+
"Close"
)
# TODO FIN
pass
...
...
@@ -310,6 +358,11 @@ class _Connection:
def
__init__
(
self
,
parent
:
Union
[
"_Server"
,
"Binding"
],
stream
:
_Stream
):
self
.
parent
=
parent
self
.
stream
=
stream
self
.
deb
=
f
"
{
self
.
__parent_str
()
}
|
{
self
.
stream
.
local_port
}
->
{
self
.
stream
.
remote_port
}
: "
logging
.
debug
(
self
.
deb
+
"Setup connection"
)
def
__parent_str
(
self
)
->
str
:
return
self
.
parent
.
local_udp_addr
if
type
(
self
.
parent
)
is
Binding
else
self
.
parent
.
binding
.
local_udp_addr
def
_pass_received_msgs
(
self
,
packets
:
List
[
_AddrPacket
])
->
None
:
self
.
stream
.
pass_received_msgs
(
packets
)
...
...
@@ -330,6 +383,7 @@ class _Connection:
return
self
.
stream
.
dequeue_all_data
()
def
close
(
self
)
->
None
:
logging
.
debug
(
self
.
deb
+
"Close"
)
if
type
(
self
.
parent
)
is
_Server
:
# noinspection PyProtectedMember
self
.
parent
.
_remove_stream
(
self
.
stream
.
remote_port
)
...
...
@@ -337,20 +391,27 @@ class _Connection:
# noinspection PyProtectedMember
self
.
parent
.
_remove_socket
(
self
.
stream
.
remote_port
)
self
.
stream
.
close
()
logging
.
debug
(
self
.
deb
+
"Closed"
)
class
_Server
:
def
__init__
(
self
,
binding
:
"Binding"
,
local_port
:
int
):
self
.
binding
=
binding
self
.
local_port
=
local_port
self
.
streams
:
Dict
[
int
,
_Stream
]
=
{}
self
.
streams
:
Dict
[
int
,
_Stream
]
=
{}
# Remote port -> _Stream
self
.
streams_lock
=
threading
.
Lock
()
self
.
backlog
=
0
self
.
backlog_connections
:
List
[
_Connection
]
=
[]
self
.
backlog_lock
=
threading
.
Lock
()
self
.
backlog_available
=
threading
.
Condition
(
self
.
backlog_lock
)
self
.
deb
=
f
"
{
self
.
binding
.
local_udp_addr
}
|
{
self
.
local_port
}
: "
logging
.
debug
(
self
.
deb
+
"Setup server"
)
def
_pass_received_msgs
(
self
,
packets
:
List
[
_AddrPacket
])
->
None
:
logging
.
debug
(
self
.
deb
+
f
"Received
{
len
(
packets
)
}
messages"
)
packet_batches
:
Dict
[
int
,
List
[
_AddrPacket
]]
=
{}
orphan_packets
:
List
[
_AddrPacket
]
=
[]
...
...
@@ -358,8 +419,7 @@ class _Server:
for
packet
in
packets
:
if
packet
.
header
.
src_port
in
self
.
streams
:
batch
=
packet_batches
[
packet
.
header
.
src_port
]
=
[]
batch
.
append
(
packet
)
packet_batches
.
setdefault
(
packet
.
header
.
src_port
,
[]).
append
(
packet
)
else
:
orphan_packets
.
append
(
packet
)
...
...
@@ -379,12 +439,13 @@ class _Server:
_RemoteInit
(
packet
.
header
.
syn_nr
,
packet
.
header
.
window
))
self
.
backlog_connections
.
append
(
_Connection
(
self
,
stream
))
self
.
backlog
-=
1
self
.
backlog_available
.
notify
()
self
.
backlog_lock
.
release
()
else
:
self
.
backlog_lock
.
release
()
print
(
"Backlog full"
)
logging
.
warning
(
self
.
deb
+
"Backlog full"
)
else
:
print
(
"Unknown connection"
)
logging
.
warning
(
self
.
deb
+
"Unknown connection"
)
self
.
streams_lock
.
release
()
...
...
@@ -395,47 +456,58 @@ class _Server:
self
.
streams_lock
.
release
()
def
_remove_stream
(
self
,
remote_port
:
int
)
->
None
:
logging
.
debug
(
self
.
deb
+
"Remove stream"
)
self
.
streams_lock
.
acquire
()
del
self
.
streams
[
remote_port
]
self
.
streams_lock
.
release
()
def
start_listen
(
self
,
backlog
:
int
)
->
None
:
logging
.
debug
(
self
.
deb
+
"Start listen"
)
self
.
backlog_lock
.
acquire
()
self
.
backlog
=
backlog
self
.
backlog_lock
.
release
()
def
accept
(
self
)
->
_Connection
:
logging
.
debug
(
self
.
deb
+
"Accepting..."
)
self
.
backlog_lock
.
acquire
()
self
.
backlog_available
.
wait_for
(
lambda
:
len
(
self
.
backlog_connections
)
>
0
)
connection
=
self
.
backlog_connections
.
pop
(
0
)
self
.
backlog_lock
.
release
()
logging
.
debug
(
self
.
deb
+
"Accepted"
)
return
connection
def
close
(
self
)
->
None
:
logging
.
debug
(
self
.
deb
+
"Closing..."
)
# noinspection PyProtectedMember
self
.
binding
.
_remove_socket
(
self
.
local_port
)
for
stream
in
self
.
streams
.
values
():
stream
.
close
()
while
len
(
self
.
streams
)
>
0
:
self
.
streams
.
pop
(
0
).
close
()
logging
.
debug
(
self
.
deb
+
"Closed"
)
class
Binding
:
def
__init__
(
self
,
protocol
:
int
,
local_udp_addr
:
Optional
[
Any
],
window_size
:
int
,
timeout
:
int
):
window_size
:
int
,
timeout_sec
:
float
):
self
.
local_udp_addr
=
local_udp_addr
self
.
window_size
=
window_size
self
.
timeout
=
timeout
self
.
timeout
=
timeout
_sec
self
.
sock
=
socket
.
socket
(
protocol
,
socket
.
SOCK_DGRAM
)
if
local_udp_addr
is
not
None
:
self
.
sock
.
bind
(
local_udp_addr
)
self
.
sockets
:
Dict
[
int
,
Union
[
_Server
,
_Connection
]]
=
{}
self
.
sockets
:
Dict
[
int
,
Union
[
_Server
,
_Connection
]]
=
{}
# Local port -> _Server / _Connection
self
.
stop
=
False
self
.
sockets_stop_lock
=
threading
.
RLock
()
self
.
read_thread
=
threading
.
Thread
(
None
,
self
.
__background
)
self
.
read_thread
.
start
()
self
.
deb
=
f
"
{
self
.
local_udp_addr
}
: "
logging
.
debug
(
self
.
deb
+
"Set up binding"
)
def
__background
(
self
)
->
None
:
poller
=
select
.
poll
()
poller
=
select
.
poll
()
# Does not work on Windows because Python is stupid (WSAPoll is a thing)
poller
.
register
(
self
.
sock
,
select
.
POLLIN
)
while
True
:
...
...
@@ -447,18 +519,20 @@ class Binding:
if
self
.
stop
:
break
while
(
self
.
sock
,
select
.
POLLIN
)
in
poll_result
:
while
len
(
poll_result
)
>
0
:
data
,
addr
=
self
.
sock
.
recvfrom
(
header_size
+
payload_size
,
socket
.
MSG_WAITALL
)
packet
=
_AddrPacket
(
_unpack_header
(
data
[:
header_size
]),
data
[
header_size
:],
addr
)
logging
.
debug
(
self
.
deb
+
f
"Packet from
{
addr
}
"
)
packet
=
_AddrPacket
(
_unpack_header
(
data
[:
header_size
]),
data
[
header_size
:],
addr
)
packet
.
data
=
packet
.
data
[:
-
(
payload_size
-
packet
.
header
.
data_length
)]
if
not
packet
.
verify_checksum
():
print
(
"Invalid checksum"
)
logging
.
warning
(
self
.
deb
+
"Invalid checksum"
)
else
:
if
packet
.
header
.
src_port
in
self
.
sockets
:
batch
=
packet_batches
[
packet
.
header
.
src_port
]
=
[]
batch
.
append
(
packet
)
if
packet
.
header
.
dest_port
in
self
.
sockets
:
packet_batches
.
setdefault
(
packet
.
header
.
dest_port
,
[]).
append
(
packet
)
else
:
print
(
"Unknown server"
)
logging
.
warning
(
self
.
deb
+
"Unknown server"
)
poll_result
=
poller
.
poll
(
0
)
...
...
@@ -472,23 +546,29 @@ class Binding:
self
.
sockets_stop_lock
.
release
()
self
.
sockets_stop_lock
.
release
()
def
_send
(
self
,
packet
:
_Packet
,
remote_udp_addr
:
Any
)
->
None
:
data
=
_pack_header
(
packet
.
header
)
+
packet
.
data
logging
.
debug
(
self
.
deb
+
f
"Send to
{
remote_udp_addr
}
"
)
data
=
_pack_header
(
packet
.
header
)
+
packet
.
data
+
bytes
(
payload_size
-
len
(
packet
.
data
))
while
len
(
data
)
>
0
:
data
=
data
[
self
.
sock
.
sendto
(
data
,
remote_udp_addr
):]
def
_remove_socket
(
self
,
local_port
:
int
)
->
None
:
logging
.
debug
(
self
.
deb
+
f
"Remove socket
{
local_port
}
"
)
self
.
sockets_stop_lock
.
acquire
()
del
self
.
sockets
[
local_port
]
self
.
sockets_stop_lock
.
release
()
def
bind_server
(
self
,
local_btcp_port
:
int
)
->
_Server
:
logging
.
debug
(
self
.
deb
+
f
"Bind server to btcp
{
local_btcp_port
}
"
)
self
.
sockets_stop_lock
.
acquire
()
server
=
self
.
sockets
[
local_btcp_port
]
=
_Server
(
self
,
local_btcp_port
)
self
.
sockets_stop_lock
.
release
()
return
server
def
connect_client
(
self
,
local_btcp_port
:
int
,
remote_btcp_port
:
int
,
remote_udp_addr
:
Any
)
->
_Connection
:
logging
.
debug
(
self
.
deb
+
f
"Connect to btcp
{
local_btcp_port
}
->
{
remote_btcp_port
}
{
remote_udp_addr
}
"
)
self
.
sockets_stop_lock
.
acquire
()
connection
=
_Connection
(
self
,
_Stream
(
self
,
local_btcp_port
,
remote_btcp_port
,
remote_udp_addr
,
self
.
window_size
,
self
.
timeout
))
...
...
@@ -497,14 +577,17 @@ class Binding:
return
connection
def
close
(
self
)
->
None
:
logging
.
debug
(
self
.
deb
+
"Closing..."
)
self
.
sockets_stop_lock
.
acquire
()
self
.
stop
=
True
self
.
sockets_stop_lock
.
release
()
self
.
read_thread
.
join
()
logging
.
debug
(
self
.
deb
+
"Thread exited"
)
self
.
sockets_stop_lock
.
acquire
()
for
server
in
self
.
sockets
.
values
()
:
se
rver
.
close
()
while
len
(
self
.
sockets
)
>
0
:
se
lf
.
sockets
[
0
]
.
close
()
self
.
sockets_stop_lock
.
release
()
self
.
sock
.
close
()
logging
.
debug
(
self
.
deb
+
"Closed"
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment