dockerfiles/anylink/dtls-2.0.9/resume_test.go

209 lines
3.9 KiB
Go

package dtls
import (
"bytes"
"crypto/tls"
"errors"
"fmt"
"net"
"sync"
"testing"
"time"
"github.com/pion/dtls/v2/pkg/crypto/selfsign"
"github.com/pion/transport/test"
)
var errMessageMissmatch = errors.New("messages missmatch")
func TestResumeClient(t *testing.T) {
DoTestResume(t, Client, Server)
}
func TestResumeServer(t *testing.T) {
DoTestResume(t, Server, Client)
}
func fatal(t *testing.T, errChan chan error, err error) {
close(errChan)
t.Fatal(err)
}
func DoTestResume(t *testing.T, newLocal, newRemote func(net.Conn, *Config) (*Conn, error)) {
// Limit runtime in case of deadlocks
lim := test.TimeOut(time.Second * 20)
defer lim.Stop()
// Check for leaking routines
report := test.CheckRoutines(t)
defer report()
certificate, err := selfsign.GenerateSelfSigned()
if err != nil {
t.Fatal(err)
}
// Generate connections
localConn1, rc1 := net.Pipe()
localConn2, rc2 := net.Pipe()
remoteConn := &backupConn{curr: rc1, next: rc2}
// Launch remote in another goroutine
errChan := make(chan error, 1)
defer func() {
err = <-errChan
if err != nil {
t.Fatal(err)
}
}()
config := &Config{
Certificates: []tls.Certificate{certificate},
InsecureSkipVerify: true,
ExtendedMasterSecret: RequireExtendedMasterSecret,
}
go func() {
var remote *Conn
var errR error
remote, errR = newRemote(remoteConn, config)
if errR != nil {
errChan <- errR
}
// Loop of read write
for i := 0; i < 2; i++ {
recv := make([]byte, 1024)
var n int
n, errR = remote.Read(recv)
if errR != nil {
errChan <- errR
}
if _, errR = remote.Write(recv[:n]); errR != nil {
errChan <- errR
}
}
errChan <- nil
}()
var local *Conn
local, err = newLocal(localConn1, config)
if err != nil {
fatal(t, errChan, err)
}
defer func() {
_ = local.Close()
}()
// Test write and read
message := []byte("Hello")
if _, err = local.Write(message); err != nil {
fatal(t, errChan, err)
}
recv := make([]byte, 1024)
var n int
n, err = local.Read(recv)
if err != nil {
fatal(t, errChan, err)
}
if !bytes.Equal(message, recv[:n]) {
fatal(t, errChan, fmt.Errorf("%w: %s != %s", errMessageMissmatch, message, recv[:n]))
}
if err = localConn1.Close(); err != nil {
fatal(t, errChan, err)
}
// Serialize and deserialize state
state := local.ConnectionState()
var b []byte
b, err = state.MarshalBinary()
if err != nil {
fatal(t, errChan, err)
}
deserialized := &State{}
if err = deserialized.UnmarshalBinary(b); err != nil {
fatal(t, errChan, err)
}
// Resume dtls connection
var resumed net.Conn
resumed, err = Resume(deserialized, localConn2, config)
if err != nil {
fatal(t, errChan, err)
}
defer func() {
_ = resumed.Close()
}()
// Test write and read on resumed connection
if _, err = resumed.Write(message); err != nil {
fatal(t, errChan, err)
}
recv = make([]byte, 1024)
n, err = resumed.Read(recv)
if err != nil {
fatal(t, errChan, err)
}
if !bytes.Equal(message, recv[:n]) {
fatal(t, errChan, fmt.Errorf("%w: %s != %s", errMessageMissmatch, message, recv[:n]))
}
}
type backupConn struct {
curr net.Conn
next net.Conn
mux sync.Mutex
}
func (b *backupConn) Read(data []byte) (n int, err error) {
n, err = b.curr.Read(data)
if err != nil && b.next != nil {
b.mux.Lock()
b.curr = b.next
b.next = nil
b.mux.Unlock()
return b.Read(data)
}
return n, err
}
func (b *backupConn) Write(data []byte) (n int, err error) {
n, err = b.curr.Write(data)
if err != nil && b.next != nil {
b.mux.Lock()
b.curr = b.next
b.next = nil
b.mux.Unlock()
return b.Write(data)
}
return n, err
}
func (b *backupConn) Close() error {
return nil
}
func (b *backupConn) LocalAddr() net.Addr {
return nil
}
func (b *backupConn) RemoteAddr() net.Addr {
return nil
}
func (b *backupConn) SetDeadline(t time.Time) error {
return nil
}
func (b *backupConn) SetReadDeadline(t time.Time) error {
return nil
}
func (b *backupConn) SetWriteDeadline(t time.Time) error {
return nil
}