Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Steven Wallis de Vries
bTCP
Commits
1d32f477
Commit
1d32f477
authored
Apr 25, 2019
by
StevenWdV
Browse files
First working prototype
parent
ac7b2e9b
Changes
3
Hide whitespace changes
Inline
Side-by-side
bTCP_client.py
View file @
1d32f477
#!/
usr/local/
bin/python3
#!/bin/python3
import
argparse
import
argparse
import
logging
import
socket
import
socket
import
struct
import
struct
import
sys
import
btcp
import
btcp
logging
.
basicConfig
(
stream
=
sys
.
stdout
,
level
=
logging
.
DEBUG
)
parser
=
argparse
.
ArgumentParser
()
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
"-w"
,
"--window"
,
help
=
"Define bTCP window size"
,
type
=
int
,
default
=
100
)
parser
.
add_argument
(
"-w"
,
"--window"
,
help
=
"Define bTCP window size"
,
type
=
int
,
default
=
100
)
parser
.
add_argument
(
"-t"
,
"--timeout"
,
help
=
"Define bTCP timeout in milliseconds"
,
type
=
int
,
default
=
100
)
parser
.
add_argument
(
"-t"
,
"--timeout"
,
help
=
"Define bTCP timeout in milliseconds"
,
type
=
int
,
default
=
100
)
parser
.
add_argument
(
"-i"
,
"--input"
,
help
=
"File to send"
,
default
=
"tmp.file"
)
parser
.
add_argument
(
"-i"
,
"--input"
,
help
=
"File to send"
,
default
=
"tmp.file"
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
binding
=
btcp
.
Binding
(
socket
.
AF_INET
,
(
""
,
9002
)
,
args
.
window
,
args
.
timeout
)
binding
=
btcp
.
Binding
(
socket
.
AF_INET
,
None
,
args
.
window
,
args
.
timeout
/
1000.
)
connection
=
binding
.
connect_client
(
0
,
0
,
(
""
,
9001
))
connection
=
binding
.
connect_client
(
0
,
0
,
(
""
,
9001
))
file
=
open
(
args
.
input
,
"rb"
)
file
=
open
(
args
.
input
,
"r
+
b"
)
data
=
file
.
read
()
data
=
file
.
read
()
file
.
close
()
file
.
close
()
connection
.
send
(
struct
.
pack
(
"!Q"
,
len
(
data
)))
connection
.
send
(
struct
.
pack
(
"!Q"
,
len
(
data
)))
connection
.
send
(
data
)
connection
.
send
(
data
)
input
(
"press enter to stop
\n
"
)
binding
.
close
()
binding
.
close
()
bTCP_server.py
View file @
1d32f477
#!/
usr/local/
bin/python3
#!/bin/python3
import
argparse
import
argparse
import
logging
import
socket
import
socket
import
struct
import
struct
import
sys
import
btcp
import
btcp
logging
.
basicConfig
(
stream
=
sys
.
stdout
,
level
=
logging
.
DEBUG
)
# Handle arguments
# Handle arguments
parser
=
argparse
.
ArgumentParser
()
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
"-w"
,
"--window"
,
help
=
"Define bTCP window size"
,
type
=
int
,
default
=
100
)
parser
.
add_argument
(
"-w"
,
"--window"
,
help
=
"Define bTCP window size"
,
type
=
int
,
default
=
100
)
...
@@ -12,7 +16,7 @@ parser.add_argument("-t", "--timeout", help="Define bTCP timeout in milliseconds
...
@@ -12,7 +16,7 @@ parser.add_argument("-t", "--timeout", help="Define bTCP timeout in milliseconds
parser
.
add_argument
(
"-o"
,
"--output"
,
help
=
"Where to store file"
,
default
=
"tmp.file"
)
parser
.
add_argument
(
"-o"
,
"--output"
,
help
=
"Where to store file"
,
default
=
"tmp.file"
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
binding
=
btcp
.
Binding
(
socket
.
AF_INET
,
(
""
,
9001
),
args
.
window
,
args
.
timeout
)
binding
=
btcp
.
Binding
(
socket
.
AF_INET
,
(
""
,
9001
),
args
.
window
,
args
.
timeout
/
1000.
)
server
=
binding
.
bind_server
(
0
)
server
=
binding
.
bind_server
(
0
)
server
.
start_listen
(
1
)
server
.
start_listen
(
1
)
...
@@ -20,11 +24,9 @@ connection = server.accept()
...
@@ -20,11 +24,9 @@ connection = server.accept()
file_size
:
int
=
struct
.
unpack
(
"!Q"
,
connection
.
receive
(
8
))[
0
]
file_size
:
int
=
struct
.
unpack
(
"!Q"
,
connection
.
receive
(
8
))[
0
]
file
=
open
(
args
.
output
,
"wb"
)
file
=
open
(
args
.
output
,
"w+b"
)
while
file_size
>
0
:
data
=
connection
.
receive
(
file_size
)
data
=
connection
.
receive_all
()
file
.
write
(
data
)
file
.
write
(
data
)
file_size
-=
len
(
data
)
file
.
close
()
file
.
close
()
binding
.
close
()
binding
.
close
()
btcp.py
View file @
1d32f477
import
binascii
import
binascii
import
bisect
import
bisect
import
copy
import
copy
import
logging
import
random
import
random
import
select
import
select
import
socket
import
socket
...
@@ -26,9 +27,19 @@ class _Flags:
...
@@ -26,9 +27,19 @@ class _Flags:
else
:
else
:
self
.
syn
,
self
.
ack
,
self
.
fin
=
flags
self
.
syn
,
self
.
ack
,
self
.
fin
=
flags
def
to
_int
(
self
)
->
int
:
def
_
_int
__
(
self
)
->
int
:
return
self
.
syn
<<
2
|
self
.
ack
<<
1
|
self
.
fin
return
self
.
syn
<<
2
|
self
.
ack
<<
1
|
self
.
fin
def
__str__
(
self
)
->
str
:
strs
:
List
[
str
]
=
[]
if
self
.
syn
:
strs
.
append
(
"SYN"
)
if
self
.
ack
:
strs
.
append
(
"ACK"
)
if
self
.
fin
:
strs
.
append
(
"FIN"
)
return
"-"
.
join
(
strs
)
class
_Header
:
class
_Header
:
def
__init__
(
self
,
src_port
:
int
,
dest_port
:
int
,
syn_nr
:
int
,
ack_nr
:
int
,
flags
:
int
,
def
__init__
(
self
,
src_port
:
int
,
dest_port
:
int
,
syn_nr
:
int
,
ack_nr
:
int
,
flags
:
int
,
...
@@ -98,6 +109,7 @@ class _AddrPacket(_Packet):
...
@@ -98,6 +109,7 @@ class _AddrPacket(_Packet):
self
.
remote_udp_addr
=
remote_udp_addr
self
.
remote_udp_addr
=
remote_udp_addr
# Info for if the connection was not initiated from our side
class
_RemoteInit
:
class
_RemoteInit
:
def
__init__
(
self
,
syn_nr
:
int
,
window_size
:
int
):
def
__init__
(
self
,
syn_nr
:
int
,
window_size
:
int
):
self
.
syn_nr
=
syn_nr
self
.
syn_nr
=
syn_nr
...
@@ -107,13 +119,12 @@ class _RemoteInit:
...
@@ -107,13 +119,12 @@ class _RemoteInit:
class
_Stream
:
class
_Stream
:
def
__init__
(
self
,
binding
:
"Binding"
,
local_port
:
int
,
def
__init__
(
self
,
binding
:
"Binding"
,
local_port
:
int
,
remote_port
:
int
,
remote_udp_addr
:
Any
,
remote_port
:
int
,
remote_udp_addr
:
Any
,
local_window_size
:
int
,
timeout
:
in
t
,
local_window_size
:
int
,
timeout
:
floa
t
,
remote_init
:
Optional
[
_RemoteInit
]
=
None
):
remote_init
:
Optional
[
_RemoteInit
]
=
None
):
self
.
binding
=
binding
self
.
binding
=
binding
self
.
local_port
=
local_port
self
.
local_port
=
local_port
self
.
remote_port
=
remote_port
self
.
remote_port
=
remote_port
self
.
local_window_size
=
local_window_size
self
.
local_window_size
=
local_window_size
self
.
remote_window_size
=
remote_init
.
window_size
self
.
remote_udp_addr
=
remote_udp_addr
self
.
remote_udp_addr
=
remote_udp_addr
self
.
timeout
=
timeout
self
.
timeout
=
timeout
self
.
expect_syn_ack
=
remote_init
is
None
self
.
expect_syn_ack
=
remote_init
is
None
...
@@ -128,35 +139,48 @@ class _Stream:
...
@@ -128,35 +139,48 @@ class _Stream:
self
.
receive_buffer_lock
=
threading
.
Lock
()
self
.
receive_buffer_lock
=
threading
.
Lock
()
self
.
receive_bytes_available
=
threading
.
Condition
(
self
.
receive_buffer_lock
)
self
.
receive_bytes_available
=
threading
.
Condition
(
self
.
receive_buffer_lock
)
if
remote_init
is
None
:
self
.
deb
=
f
"
{
self
.
binding
.
local_udp_addr
}
|
{
self
.
local_port
}
->
{
self
.
remote_port
}
: "
self
.
send_syn_nr
=
random
.
randint
(
0
,
0xffFF
)
# Next to be sent
self
.
first_send_syn_nr
=
random
.
randint
(
0
,
0xffFF
)
if
remote_init
is
None
:
# We initiate the connection
self
.
send_syn_nr
=
self
.
first_send_syn_nr
# Next to be sent
self
.
send_ack_nr
=
self
.
send_syn_nr
# First not ACKed by other
self
.
send_ack_nr
=
self
.
send_syn_nr
# First not ACKed by other
self
.
__send_to_be_acked
(
b
""
,
self
.
send_syn_nr
,
_Flags
((
True
,
False
,
False
)))
self
.
__send_to_be_acked
(
b
""
,
self
.
send_syn_nr
,
_Flags
((
True
,
False
,
False
)))
else
:
else
:
self
.
remote_window_size
=
remote_init
.
window_size
self
.
recv_syn_nr
=
remote_init
.
syn_nr
# Last received
self
.
recv_syn_nr
=
remote_init
.
syn_nr
# Last received
self
.
recv_ack_nr
=
remote_init
.
syn_nr
+
1
# First still to be received
self
.
recv_ack_nr
=
remote_init
.
syn_nr
+
1
# First still to be received
self
.
send_syn_nr
=
random
.
randint
(
0
,
0xffFF
)
# Next to be sent
self
.
send_syn_nr
=
self
.
first_send_syn_nr
# Next to be sent
self
.
send_ack_nr
=
self
.
send_syn_nr
# First not ACKed by other
self
.
send_ack_nr
=
self
.
send_syn_nr
# First not ACKed by other
self
.
first_recv_syn_nr
=
self
.
recv_syn_nr
self
.
__send_to_be_acked
(
b
""
,
self
.
send_syn_nr
,
_Flags
((
True
,
True
,
False
)))
self
.
__send_to_be_acked
(
b
""
,
self
.
send_syn_nr
,
_Flags
((
True
,
True
,
False
)))
self
.
send_syn_nr
+=
1
self
.
send_syn_nr
+=
1
logging
.
debug
(
self
.
deb
+
"Setup stream"
)
def
__send_to_be_acked
(
self
,
data
:
bytes
,
syn_nr
:
int
,
flags
=
_Flags
((
False
,
True
,
False
)))
->
None
:
def
__send_to_be_acked
(
self
,
data
:
bytes
,
syn_nr
:
int
,
flags
=
_Flags
((
False
,
True
,
False
)))
->
None
:
logging
.
debug
(
self
.
deb
+
f
"Sending #
{
syn_nr
}
{
flags
}{
f
' with ACK
{
self
.
recv_ack_nr
}
' if flags.ack else ''
}
"
f
"with
{
len
(
data
)
}
bytes"
)
self
.
send_syn_nr
=
max
(
syn_nr
,
self
.
send_syn_nr
)
self
.
send_syn_nr
=
max
(
syn_nr
,
self
.
send_syn_nr
)
header
=
_Header
(
self
.
local_port
,
self
.
remote_port
,
syn_nr
,
self
.
recv_ack_nr
if
flags
.
ack
else
0
,
header
=
_Header
(
self
.
local_port
,
self
.
remote_port
,
syn_nr
,
self
.
recv_ack_nr
if
flags
.
ack
else
0
,
flags
.
to_int
(
),
self
.
local_window_size
,
len
(
data
))
int
(
flags
),
self
.
local_window_size
,
len
(
data
))
packet
=
_TimestampPacket
(
header
,
data
,
time
.
clock
())
packet
=
_TimestampPacket
(
header
,
data
,
time
.
perf_counter
())
packet
.
set_checksum
()
packet
.
set_checksum
()
self
.
sent_ack_buffer
.
append
(
packet
)
self
.
sent_ack_buffer
.
append
(
packet
)
# noinspection PyProtectedMember
# noinspection PyProtectedMember
self
.
binding
.
_send
(
packet
,
self
.
remote_udp_addr
)
self
.
binding
.
_send
(
packet
,
self
.
remote_udp_addr
)
def
__send_ack
(
self
)
->
None
:
def
__send_ack
(
self
)
->
None
:
logging
.
debug
(
self
.
deb
+
f
"Sending ACK
{
self
.
recv_ack_nr
}
"
)
ack_header
=
_Header
(
self
.
local_port
,
self
.
remote_port
,
self
.
send_syn_nr
,
self
.
recv_ack_nr
,
ack_header
=
_Header
(
self
.
local_port
,
self
.
remote_port
,
self
.
send_syn_nr
,
self
.
recv_ack_nr
,
_Flags
((
False
,
True
,
False
))
.
to_int
(
),
self
.
local_window_size
,
0
)
int
(
_Flags
((
False
,
True
,
False
))),
self
.
local_window_size
,
0
)
ack_packet
=
_Packet
(
ack_header
,
b
""
)
ack_packet
=
_Packet
(
ack_header
,
b
""
)
ack_packet
.
set_checksum
()
ack_packet
.
set_checksum
()
...
@@ -164,41 +188,49 @@ class _Stream:
...
@@ -164,41 +188,49 @@ class _Stream:
self
.
binding
.
_send
(
ack_packet
,
self
.
remote_udp_addr
)
self
.
binding
.
_send
(
ack_packet
,
self
.
remote_udp_addr
)
def
pass_received_msgs
(
self
,
packets
:
List
[
_AddrPacket
])
->
None
:
def
pass_received_msgs
(
self
,
packets
:
List
[
_AddrPacket
])
->
None
:
logging
.
debug
(
self
.
deb
+
f
"Received
{
len
(
packets
)
}
messages"
)
send_ack
=
False
send_ack
=
False
for
packet
in
packets
:
for
packet
in
packets
:
if
packet
.
remote_udp_addr
!=
self
.
remote_udp_addr
:
print
(
"Wrong remote UDP address"
)
continue
flags
=
packet
.
header
.
flags_obj
()
flags
=
packet
.
header
.
flags_obj
()
logging
.
debug
(
self
.
deb
+
f
"Received #
{
packet
.
header
.
syn_nr
}
{
flags
}
"
f
"
{
f
' with ACK
{
packet
.
header
.
ack_nr
}
' if flags.ack else ''
}
"
f
"with
{
len
(
packet
.
data
)
}
bytes"
)
if
self
.
expect_syn_ack
:
if
self
.
expect_syn_ack
:
if
flags
.
syn
and
flags
.
ack
and
not
flags
.
fin
:
if
flags
.
syn
and
flags
.
ack
and
not
flags
.
fin
:
if
packet
.
header
.
ack_nr
!=
self
.
send_ack_nr
:
if
packet
.
header
.
ack_nr
!=
self
.
send_ack_nr
+
1
:
print
(
"Wrong ACK nr in SYN-ACK"
)
logging
.
warning
(
self
.
deb
+
"Wrong ACK nr in SYN-ACK"
)
continue
continue
self
.
expect_syn_ack
=
False
self
.
expect_syn_ack
=
False
self
.
recv_syn_nr
=
packet
.
header
.
syn_nr
self
.
recv_syn_nr
=
packet
.
header
.
syn_nr
self
.
recv_ack_nr
=
packet
.
header
.
syn_nr
+
1
self
.
recv_ack_nr
=
packet
.
header
.
syn_nr
+
1
self
.
first_recv_syn_nr
=
self
.
recv_syn_nr
self
.
remote_window_size
=
packet
.
header
.
window
self
.
remote_window_size
=
packet
.
header
.
window
else
:
else
:
print
(
"SYN-ACK expected"
)
logging
.
warning
(
self
.
deb
+
"SYN-ACK expected"
)
continue
continue
elif
flags
.
syn
:
elif
flags
.
syn
:
print
(
"Unexpected SYN"
)
if
packet
.
header
.
syn_nr
==
self
.
first_recv_syn_nr
:
# TODO simultaneous open?
logging
.
debug
(
self
.
deb
+
"Spurious SYN"
)
continue
continue
else
:
logging
.
warning
(
self
.
deb
+
"Unexpected SYN"
)
# TODO simultaneous open
continue
if
packet
.
header
.
syn_nr
>
self
.
recv_syn_nr
+
self
.
local_window_size
:
if
packet
.
header
.
syn_nr
>
self
.
recv_syn_nr
+
self
.
local_window_size
:
print
(
"Too high SYN"
)
logging
.
warning
(
self
.
deb
+
"Too high SYN"
)
continue
continue
# TODO if flags.fin:
# TODO if flags.fin:
if
flags
.
ack
:
if
flags
.
ack
:
if
packet
.
header
.
ack_nr
>
self
.
send_syn_nr
+
1
:
if
packet
.
header
.
ack_nr
>
self
.
send_syn_nr
+
1
:
print
(
"Too high ACK"
)
logging
.
warning
(
self
.
deb
+
"Too high ACK"
)
continue
if
packet
.
header
.
ack_nr
<
self
.
first_send_syn_nr
:
logging
.
warning
(
self
.
deb
+
"Too low ACK"
)
continue
continue
if
packet
.
header
.
ack_nr
>
self
.
send_ack_nr
:
if
packet
.
header
.
ack_nr
>
self
.
send_ack_nr
:
self
.
send_ack_nr
=
packet
.
header
.
ack_nr
self
.
send_ack_nr
=
packet
.
header
.
ack_nr
...
@@ -225,9 +257,12 @@ class _Stream:
...
@@ -225,9 +257,12 @@ class _Stream:
elif
packet
.
header
.
syn_nr
>
self
.
recv_ack_nr
:
elif
packet
.
header
.
syn_nr
>
self
.
recv_ack_nr
:
# There is a gap, store this packet
# There is a gap, store this packet
insert_index
=
bisect
.
bisect_left
(
self
.
receive_ack_buffer
,
packet
)
insert_index
=
bisect
.
bisect_left
(
self
.
receive_ack_buffer
,
packet
)
# Check for duplicate packet
if
(
insert_index
+
1
>
len
(
self
.
receive_ack_buffer
)
if
(
insert_index
+
1
<
len
(
self
.
receive_ack_buffer
)
or
self
.
receive_ack_buffer
[
insert_index
+
1
].
header
.
syn_nr
!=
packet
.
header
.
syn_nr
):
and
self
.
receive_ack_buffer
[
insert_index
+
1
].
header
.
syn_nr
==
packet
.
header
.
syn_nr
):
assert
self
.
receive_ack_buffer
[
insert_index
+
1
].
data
==
packet
.
data
logging
.
debug
(
self
.
deb
+
"Spurious retransmission"
)
else
:
self
.
receive_ack_buffer
.
insert
(
insert_index
,
packet
)
self
.
receive_ack_buffer
.
insert
(
insert_index
,
packet
)
send_ack
=
True
send_ack
=
True
...
@@ -237,27 +272,34 @@ class _Stream:
...
@@ -237,27 +272,34 @@ class _Stream:
self
.
send_buffer_lock
.
acquire
()
self
.
send_buffer_lock
.
acquire
()
data_to_send
=
len
(
self
.
send_buffer
)
>
0
data_to_send
=
len
(
self
.
send_buffer
)
>
0
self
.
send_buffer_lock
.
release
()
self
.
send_buffer_lock
.
release
()
if
data_to_send
:
if
not
data_to_send
:
self
.
__send_ack
()
self
.
__send_ack
()
def
poll_sender
(
self
)
->
None
:
def
poll_sender
(
self
)
->
None
:
lost
=
[
p
for
p
in
self
.
sent_ack_buffer
if
time
.
perf_counter
()
-
p
.
timestamp
>=
self
.
timeout
]
self
.
sent_ack_buffer
=
[
p
for
p
in
self
.
sent_ack_buffer
if
time
.
perf_counter
()
-
p
.
timestamp
<
self
.
timeout
]
if
len
(
lost
)
>
0
:
logging
.
debug
(
self
.
deb
+
f
"
{
len
(
lost
)
}
lost messages"
)
for
p
in
lost
:
p
.
timestamp
=
time
.
perf_counter
()
self
.
__send_to_be_acked
(
p
.
data
,
p
.
header
.
syn_nr
,
p
.
header
.
flags_obj
())
if
self
.
expect_syn_ack
:
if
self
.
expect_syn_ack
:
return
return
lost
=
[
p
for
p
in
self
.
sent_ack_buffer
if
time
.
clock
()
-
p
.
timestamp
>=
self
.
timeout
]
if
len
(
self
.
send_buffer
)
>
0
:
for
p
in
lost
:
logging
.
debug
(
self
.
deb
+
f
"
{
len
(
self
.
send_buffer
)
}
unsent messages"
)
p
.
timestamp
=
time
.
clock
()
self
.
__send_to_be_acked
(
p
.
data
,
p
.
header
.
syn_nr
,
p
.
header
.
flags_obj
())
if
len
(
self
.
sent_ack_buffer
)
<
self
.
remote_window_size
:
if
len
(
self
.
sent_ack_buffer
)
<
self
.
remote_window_size
:
self
.
send_buffer_lock
.
acquire
()
self
.
send_buffer_lock
.
acquire
()
while
len
(
self
.
sent_ack_buffer
)
<
self
.
remote_window_size
:
while
len
(
self
.
send_buffer
)
>
0
and
len
(
self
.
sent_ack_buffer
)
<
self
.
remote_window_size
:
data
=
self
.
send_buffer
.
pop
(
0
)
data
=
self
.
send_buffer
.
pop
(
0
)
self
.
__send_to_be_acked
(
data
+
bytes
(
payload_size
-
len
(
data
))
,
self
.
send_syn_nr
)
self
.
__send_to_be_acked
(
data
,
self
.
send_syn_nr
)
self
.
send_syn_nr
+=
1
self
.
send_syn_nr
+=
1
self
.
send_buffer_lock
.
release
()
self
.
send_buffer_lock
.
release
()
def
enqueue_data
(
self
,
data
:
bytes
)
->
None
:
def
enqueue_data
(
self
,
data
:
bytes
)
->
None
:
logging
.
debug
(
self
.
deb
+
f
"Enqueuing
{
len
(
data
)
}
bytes"
)
self
.
send_buffer_lock
.
acquire
()
self
.
send_buffer_lock
.
acquire
()
offset
=
0
offset
=
0
while
offset
<
len
(
data
):
while
offset
<
len
(
data
):
...
@@ -268,11 +310,12 @@ class _Stream:
...
@@ -268,11 +310,12 @@ class _Stream:
def
__at_least_received
(
self
,
count
:
int
)
->
bool
:
def
__at_least_received
(
self
,
count
:
int
)
->
bool
:
i
=
0
i
=
0
while
i
<
len
(
self
.
receive_buffer
)
and
count
>
0
:
while
i
<
len
(
self
.
receive_buffer
)
and
count
>
0
:
count
+
=
len
(
self
.
receive_buffer
[
i
])
count
-
=
len
(
self
.
receive_buffer
[
i
])
i
+=
1
i
+=
1
return
count
<=
0
return
count
<=
0
def
dequeue_data
(
self
,
count
:
int
)
->
bytes
:
def
dequeue_data
(
self
,
count
:
int
)
->
bytes
:
logging
.
debug
(
self
.
deb
+
f
"Dequeuing
{
count
}
bytes..."
)
self
.
receive_buffer_lock
.
acquire
()
self
.
receive_buffer_lock
.
acquire
()
self
.
receive_bytes_available
.
wait_for
(
lambda
:
self
.
__at_least_received
(
count
))
self
.
receive_bytes_available
.
wait_for
(
lambda
:
self
.
__at_least_received
(
count
))
...
@@ -281,10 +324,12 @@ class _Stream:
...
@@ -281,10 +324,12 @@ class _Stream:
data
+=
self
.
receive_buffer
.
pop
(
0
)
data
+=
self
.
receive_buffer
.
pop
(
0
)
if
len
(
data
)
<
count
:
if
len
(
data
)
<
count
:
data
+=
self
.
receive_buffer
[
0
][:
len
(
data
)
-
count
]
bytes_short
=
count
-
len
(
data
)
self
.
receive_buffer
[
0
]
=
self
.
receive_buffer
[
0
][
len
(
data
)
-
count
:]
data
+=
self
.
receive_buffer
[
0
][:
bytes_short
]
self
.
receive_buffer
[
0
]
=
self
.
receive_buffer
[
0
][
bytes_short
:]
self
.
receive_buffer_lock
.
release
()
self
.
receive_buffer_lock
.
release
()
logging
.
debug
(
self
.
deb
+
"Dequeued"
)
return
data
return
data
def
dequeue_is_available
(
self
,
count
:
int
)
->
bool
:
def
dequeue_is_available
(
self
,
count
:
int
)
->
bool
:
...
@@ -294,14 +339,17 @@ class _Stream:
...
@@ -294,14 +339,17 @@ class _Stream:
return
available
return
available
def
dequeue_all_data
(
self
)
->
bytes
:
def
dequeue_all_data
(
self
)
->
bytes
:
logging
.
debug
(
self
.
deb
+
"Dequeuing all data..."
)
self
.
receive_buffer_lock
.
acquire
()
self
.
receive_buffer_lock
.
acquire
()
data
=
b
""
data
=
b
""
while
len
(
self
.
receive_buffer
)
>
0
:
while
len
(
self
.
receive_buffer
)
>
0
:
data
+=
self
.
receive_buffer
.
pop
(
0
)
data
+=
self
.
receive_buffer
.
pop
(
0
)
self
.
receive_buffer_lock
.
release
()
self
.
receive_buffer_lock
.
release
()
logging
.
debug
(
self
.
deb
+
f
"Dequeued
{
len
(
data
)
}
bytes"
)
return
data
return
data
def
close
(
self
)
->
None
:
def
close
(
self
)
->
None
:
logging
.
debug
(
self
.
deb
+
"Close"
)
# TODO FIN
# TODO FIN
pass
pass
...
@@ -310,6 +358,11 @@ class _Connection:
...
@@ -310,6 +358,11 @@ class _Connection:
def
__init__
(
self
,
parent
:
Union
[
"_Server"
,
"Binding"
],
stream
:
_Stream
):
def
__init__
(
self
,
parent
:
Union
[
"_Server"
,
"Binding"
],
stream
:
_Stream
):
self
.
parent
=
parent
self
.
parent
=
parent
self
.
stream
=
stream
self
.
stream
=
stream
self
.
deb
=
f
"
{
self
.
__parent_str
()
}
|
{
self
.
stream
.
local_port
}
->
{
self
.
stream
.
remote_port
}
: "
logging
.
debug
(
self
.
deb
+
"Setup connection"
)
def
__parent_str
(
self
)
->
str
:
return
self
.
parent
.
local_udp_addr
if
type
(
self
.
parent
)
is
Binding
else
self
.
parent
.
binding
.
local_udp_addr
def
_pass_received_msgs
(
self
,
packets
:
List
[
_AddrPacket
])
->
None
:
def
_pass_received_msgs
(
self
,
packets
:
List
[
_AddrPacket
])
->
None
:
self
.
stream
.
pass_received_msgs
(
packets
)
self
.
stream
.
pass_received_msgs
(
packets
)
...
@@ -330,6 +383,7 @@ class _Connection:
...
@@ -330,6 +383,7 @@ class _Connection:
return
self
.
stream
.
dequeue_all_data
()
return
self
.
stream
.
dequeue_all_data
()
def
close
(
self
)
->
None
:
def
close
(
self
)
->
None
:
logging
.
debug
(
self
.
deb
+
"Close"
)
if
type
(
self
.
parent
)
is
_Server
:
if
type
(
self
.
parent
)
is
_Server
:
# noinspection PyProtectedMember
# noinspection PyProtectedMember
self
.
parent
.
_remove_stream
(
self
.
stream
.
remote_port
)
self
.
parent
.
_remove_stream
(
self
.
stream
.
remote_port
)
...
@@ -337,20 +391,27 @@ class _Connection:
...
@@ -337,20 +391,27 @@ class _Connection:
# noinspection PyProtectedMember
# noinspection PyProtectedMember
self
.
parent
.
_remove_socket
(
self
.
stream
.
remote_port
)
self
.
parent
.
_remove_socket
(
self
.
stream
.
remote_port
)
self
.
stream
.
close
()
self
.
stream
.
close
()
logging
.
debug
(
self
.
deb
+
"Closed"
)
class
_Server
:
class
_Server
:
def
__init__
(
self
,
binding
:
"Binding"
,
local_port
:
int
):
def
__init__
(
self
,
binding
:
"Binding"
,
local_port
:
int
):
self
.
binding
=
binding
self
.
binding
=
binding
self
.
local_port
=
local_port
self
.
local_port
=
local_port
self
.
streams
:
Dict
[
int
,
_Stream
]
=
{}
self
.
streams
:
Dict
[
int
,
_Stream
]
=
{}
# Remote port -> _Stream
self
.
streams_lock
=
threading
.
Lock
()
self
.
streams_lock
=
threading
.
Lock
()
self
.
backlog
=
0
self
.
backlog
=
0
self
.
backlog_connections
:
List
[
_Connection
]
=
[]
self
.
backlog_connections
:
List
[
_Connection
]
=
[]
self
.
backlog_lock
=
threading
.
Lock
()
self
.
backlog_lock
=
threading
.
Lock
()
self
.
backlog_available
=
threading
.
Condition
(
self
.
backlog_lock
)
self
.
deb
=
f
"
{
self
.
binding
.
local_udp_addr
}
|
{
self
.
local_port
}
: "
logging
.
debug
(
self
.
deb
+
"Setup server"
)
def
_pass_received_msgs
(
self
,
packets
:
List
[
_AddrPacket
])
->
None
:
def
_pass_received_msgs
(
self
,
packets
:
List
[
_AddrPacket
])
->
None
:
logging
.
debug
(
self
.
deb
+
f
"Received
{
len
(
packets
)
}
messages"
)
packet_batches
:
Dict
[
int
,
List
[
_AddrPacket
]]
=
{}
packet_batches
:
Dict
[
int
,
List
[
_AddrPacket
]]
=
{}
orphan_packets
:
List
[
_AddrPacket
]
=
[]
orphan_packets
:
List
[
_AddrPacket
]
=
[]
...
@@ -358,8 +419,7 @@ class _Server:
...
@@ -358,8 +419,7 @@ class _Server:
for
packet
in
packets
:
for
packet
in
packets
:
if
packet
.
header
.
src_port
in
self
.
streams
:
if
packet
.
header
.
src_port
in
self
.
streams
:
batch
=
packet_batches
[
packet
.
header
.
src_port
]
=
[]
packet_batches
.
setdefault
(
packet
.
header
.
src_port
,
[]).
append
(
packet
)
batch
.
append
(
packet
)
else
:
else
:
orphan_packets
.
append
(
packet
)
orphan_packets
.
append
(
packet
)
...
@@ -379,12 +439,13 @@ class _Server:
...
@@ -379,12 +439,13 @@ class _Server:
_RemoteInit
(
packet
.
header
.
syn_nr
,
packet
.
header
.
window
))
_RemoteInit
(
packet
.
header
.
syn_nr
,
packet
.
header
.
window
))
self
.
backlog_connections
.
append
(
_Connection
(
self
,
stream
))
self
.
backlog_connections
.
append
(
_Connection
(
self
,
stream
))
self
.
backlog
-=
1
self
.
backlog
-=
1
self
.
backlog_available
.
notify
()
self
.
backlog_lock
.
release
()
self
.
backlog_lock
.
release
()
else
:
else
:
self
.
backlog_lock
.
release
()
self
.
backlog_lock
.
release
()
print
(
"Backlog full"
)
logging
.
warning
(
self
.
deb
+
"Backlog full"
)
else
:
else
:
print
(
"Unknown connection"
)
logging
.
warning
(
self
.
deb
+
"Unknown connection"
)
self
.
streams_lock
.
release
()
self
.
streams_lock
.
release
()
...
@@ -395,47 +456,58 @@ class _Server:
...
@@ -395,47 +456,58 @@ class _Server:
self
.
streams_lock
.
release
()
self
.
streams_lock
.
release
()
def
_remove_stream
(
self
,
remote_port
:
int
)
->
None
:
def
_remove_stream
(
self
,
remote_port
:
int
)
->
None
:
logging
.
debug
(
self
.
deb
+
"Remove stream"
)
self
.
streams_lock
.
acquire
()