dockerfiles/anylink/dtls-2.0.9/e2e/e2e_test.go

330 lines
7.6 KiB
Go

// +build !js
package e2e
import (
"context"
"crypto/tls"
"errors"
"fmt"
"io"
"net"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/pion/dtls/v2"
"github.com/pion/dtls/v2/pkg/crypto/selfsign"
"github.com/pion/transport/test"
)
// Shared parameters for the end-to-end DTLS tests.
const (
	// testMessage is the payload both peers write and expect to read.
	testMessage = "Hello World"
	// testTimeLimit is the timeout comm.assert uses while waiting for messages.
	testTimeLimit = time.Second * 5
	// messageRetry is the pause between successive writes in simpleReadWrite.
	messageRetry = time.Millisecond * 200
)

// errServerTimeout is reported by clientPion when serverReady is not
// signalled within its wait window.
var errServerTimeout = errors.New("waiting on serverReady err: timeout")
// randomPort binds a UDP socket to 127.0.0.1:0 so the OS picks a free port,
// closes the socket, and returns the chosen port number. Because the socket
// is released before returning, the port is only likely (not guaranteed) to
// still be free when the caller binds it.
func randomPort(t testing.TB) int {
	t.Helper()

	pc, err := net.ListenPacket("udp4", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("failed to pickPort: %v", err)
	}
	defer func() {
		_ = pc.Close()
	}()

	udpAddr, ok := pc.LocalAddr().(*net.UDPAddr)
	if !ok {
		t.Fatalf("unknown addr type %T", pc.LocalAddr())
		return 0
	}
	return udpAddr.Port
}
// simpleReadWrite performs one side of the echo exchange on conn.
//
// A background goroutine reads a single message and forwards it on outChan,
// then increments the shared messageRecvCount. Read failures go to errChan
// instead. Meanwhile the calling goroutine writes testMessage every
// messageRetry until the counter reaches 2, meaning both peers have
// received. It sends any write error to errChan and stops.
func simpleReadWrite(errChan chan error, outChan chan string, conn io.ReadWriter, messageRecvCount *uint64) {
	go func() {
		buf := make([]byte, 8192)
		n, readErr := conn.Read(buf)
		if readErr != nil {
			errChan <- readErr
			return
		}
		outChan <- string(buf[:n])
		atomic.AddUint64(messageRecvCount, 1)
	}()

	for atomic.LoadUint64(messageRecvCount) != 2 {
		if _, writeErr := conn.Write([]byte(testMessage)); writeErr != nil {
			errChan <- writeErr
			return
		}
		time.Sleep(messageRetry)
	}
}
// comm holds the state shared by one client/server pair during an
// end-to-end run.
type comm struct {
	ctx          context.Context
	clientConfig *dtls.Config
	serverConfig *dtls.Config
	serverPort   int

	// messageRecvCount is incremented by each reader in simpleReadWrite;
	// both sides keep writing until it reaches 2.
	messageRecvCount *uint64 // Counter to make sure both sides got a message

	// clientMutex is held by clientPion while it dials and uses clientConn.
	clientMutex *sync.Mutex
	clientConn  net.Conn

	// serverMutex is held by serverPion while it listens, accepts and uses
	// serverConn.
	serverMutex    *sync.Mutex
	serverConn     net.Conn
	serverListener net.Listener

	// serverReady is signalled by the server side once it is listening;
	// clientPion waits on it before dialing.
	serverReady chan struct{}
	errChan     chan error
	clientChan  chan string
	serverChan  chan string

	// client and server are the goroutine bodies launched by assert.
	client func(*comm)
	server func(*comm)
}
// newComm builds a comm for a single test run. The returned value has a
// zeroed receive counter, fresh mutexes, and new unbuffered channels. server
// and client are the functions assert will start as goroutines.
func newComm(ctx context.Context, clientConfig, serverConfig *dtls.Config, serverPort int, server, client func(*comm)) *comm {
	return &comm{
		ctx:              ctx,
		clientConfig:     clientConfig,
		serverConfig:     serverConfig,
		serverPort:       serverPort,
		messageRecvCount: new(uint64),
		clientMutex:      new(sync.Mutex),
		serverMutex:      new(sync.Mutex),
		serverReady:      make(chan struct{}),
		errChan:          make(chan error),
		clientChan:       make(chan string),
		serverChan:       make(chan string),
		server:           server,
		client:           client,
	}
}
// assert starts c.client and c.server as goroutines, then waits until both
// clientChan and serverChan have delivered testMessage.
//
// It fails the test on any value from errChan, on a mismatched message, or
// when testTimeLimit elapses before both messages arrive. On return (success
// or failure) it closes whichever of clientConn, serverConn and
// serverListener are non-nil.
func (c *comm) assert(t *testing.T) {
	// DTLS Client
	go c.client(c)

	// DTLS Server
	go c.server(c)

	defer func() {
		if c.clientConn != nil {
			if err := c.clientConn.Close(); err != nil {
				t.Fatal(err)
			}
		}
		if c.serverConn != nil {
			if err := c.serverConn.Close(); err != nil {
				t.Fatal(err)
			}
		}
		if c.serverListener != nil {
			if err := c.serverListener.Close(); err != nil {
				t.Fatal(err)
			}
		}
	}()

	// A single timer bounds the whole exchange. Calling time.After inside
	// the select would create a new deadline on every loop iteration, so
	// each received message would silently extend the limit.
	timeout := time.NewTimer(testTimeLimit)
	defer timeout.Stop()

	seenClient, seenServer := false, false
	for !seenClient || !seenServer {
		select {
		case err := <-c.errChan:
			t.Fatal(err)
		case <-timeout.C:
			t.Fatalf("Test timeout, seenClient %t seenServer %t", seenClient, seenServer)
		case clientMsg := <-c.clientChan:
			if clientMsg != testMessage {
				t.Fatalf("clientMsg does not equal test message: %s %s", clientMsg, testMessage)
			}
			seenClient = true
		case serverMsg := <-c.serverChan:
			if serverMsg != testMessage {
				t.Fatalf("serverMsg does not equal test message: %s %s", serverMsg, testMessage)
			}
			seenServer = true
		}
	}
}
// clientPion is the pion/dtls client side of an e2e run. It waits up to one
// second for the server to signal serverReady, dials 127.0.0.1:serverPort
// with c.clientConfig and c.ctx, stores the connection in c.clientConn, and
// then runs simpleReadWrite on it. Any failure is sent on c.errChan.
func clientPion(c *comm) {
	select {
	case <-c.serverReady:
		// OK
	case <-time.After(time.Second):
		// The server never reported ready. Report that and stop, rather than
		// dialing anyway and possibly emitting a second, misleading error.
		c.errChan <- errServerTimeout
		return
	}

	c.clientMutex.Lock()
	defer c.clientMutex.Unlock()

	var err error
	c.clientConn, err = dtls.DialWithContext(c.ctx, "udp",
		&net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: c.serverPort},
		c.clientConfig,
	)
	if err != nil {
		c.errChan <- err
		return
	}

	simpleReadWrite(c.errChan, c.clientChan, c.clientConn, c.messageRecvCount)
}
// serverPion is the pion/dtls server side of an e2e run. It listens on
// 127.0.0.1:serverPort with c.serverConfig, signals serverReady, accepts one
// connection into c.serverConn, and runs simpleReadWrite on it. Any failure
// is sent on c.errChan.
func serverPion(c *comm) {
	c.serverMutex.Lock()
	defer c.serverMutex.Unlock()

	var err error
	c.serverListener, err = dtls.Listen("udp",
		&net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: c.serverPort},
		c.serverConfig,
	)
	if err != nil {
		c.errChan <- err
		return
	}

	// Signal readiness by closing the channel rather than sending on it.
	// A waiting client is released exactly as before, but the server can
	// no longer block here forever if the client already stopped waiting
	// (clientPion only waits one second).
	close(c.serverReady)

	c.serverConn, err = c.serverListener.Accept()
	if err != nil {
		c.errChan <- err
		return
	}

	simpleReadWrite(c.errChan, c.serverChan, c.serverConn, c.messageRecvCount)
}
/*
A simple DTLS client and server can communicate:
 - Assert that messages can be sent in both directions
 - Assert that Close() works on both ends
 - Assert that no goroutines are leaked
*/
// testPionE2ESimple runs one subtest per certificate-based cipher suite
// listed below. Each subtest creates a self-signed "localhost" certificate,
// shares one config between both peers, and checks the message exchange via
// comm.assert. The whole run is wrapped in test.TimeOut(30s) and
// test.CheckRoutines.
func testPionE2ESimple(t *testing.T, server, client func(*comm)) {
	limit := test.TimeOut(30 * time.Second)
	defer limit.Stop()

	checkRoutines := test.CheckRoutines(t)
	defer checkRoutines()

	suites := []dtls.CipherSuiteID{
		dtls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
		dtls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
	}
	for _, suite := range suites {
		suite := suite // capture per iteration for the subtest closure
		t.Run(suite.String(), func(t *testing.T) {
			ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
			defer cancel()

			cert, err := selfsign.GenerateSelfSignedWithDNS("localhost")
			if err != nil {
				t.Fatal(err)
			}

			cfg := &dtls.Config{
				Certificates:       []tls.Certificate{cert},
				CipherSuites:       []dtls.CipherSuiteID{suite},
				InsecureSkipVerify: true,
			}
			newComm(ctx, cfg, cfg, randomPort(t), server, client).assert(t)
		})
	}
}
// testPionE2ESimplePSK runs one subtest per PSK cipher suite listed below.
// Both peers share a config with a fixed pre-shared key and identity hint,
// and comm.assert checks the message exchange. The whole run is wrapped in
// test.TimeOut(30s) and test.CheckRoutines.
func testPionE2ESimplePSK(t *testing.T, server, client func(*comm)) {
	limit := test.TimeOut(30 * time.Second)
	defer limit.Stop()

	checkRoutines := test.CheckRoutines(t)
	defer checkRoutines()

	suites := []dtls.CipherSuiteID{
		dtls.TLS_PSK_WITH_AES_128_CCM,
		dtls.TLS_PSK_WITH_AES_128_CCM_8,
		dtls.TLS_PSK_WITH_AES_128_GCM_SHA256,
	}
	for _, suite := range suites {
		suite := suite // capture per iteration for the subtest closure
		t.Run(suite.String(), func(t *testing.T) {
			ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
			defer cancel()

			cfg := &dtls.Config{
				// The same key is returned regardless of the hint received.
				PSK: func(_ []byte) ([]byte, error) {
					return []byte{0xAB, 0xC1, 0x23}, nil
				},
				PSKIdentityHint: []byte{0x01, 0x02, 0x03, 0x04, 0x05},
				CipherSuites:    []dtls.CipherSuiteID{suite},
			}
			newComm(ctx, cfg, cfg, randomPort(t), server, client).assert(t)
		})
	}
}
// testPionE2EMTUs runs the certificate-based exchange (AES-128-GCM suite)
// once per MTU value listed below, using a 10s context per subtest. The whole
// run is wrapped in test.TimeOut(30s) and test.CheckRoutines.
func testPionE2EMTUs(t *testing.T, server, client func(*comm)) {
	limit := test.TimeOut(30 * time.Second)
	defer limit.Stop()

	checkRoutines := test.CheckRoutines(t)
	defer checkRoutines()

	mtus := []int{10000, 1000, 100}
	for _, size := range mtus {
		size := size // capture per iteration for the subtest closure
		t.Run(fmt.Sprintf("MTU%d", size), func(t *testing.T) {
			ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
			defer cancel()

			cert, err := selfsign.GenerateSelfSignedWithDNS("localhost")
			if err != nil {
				t.Fatal(err)
			}

			cfg := &dtls.Config{
				Certificates:       []tls.Certificate{cert},
				CipherSuites:       []dtls.CipherSuiteID{dtls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
				InsecureSkipVerify: true,
				MTU:                size,
			}
			newComm(ctx, cfg, cfg, randomPort(t), server, client).assert(t)
		})
	}
}
// TestPionE2ESimple runs the certificate-based cipher suite checks with
// pion/dtls on both the server and the client side.
func TestPionE2ESimple(t *testing.T) {
	testPionE2ESimple(t, serverPion, clientPion)
}
// TestPionE2ESimplePSK runs the PSK cipher suite checks with pion/dtls on
// both the server and the client side.
func TestPionE2ESimplePSK(t *testing.T) {
	testPionE2ESimplePSK(t, serverPion, clientPion)
}
// TestPionE2EMTUs runs the MTU variation checks with pion/dtls on both the
// server and the client side.
func TestPionE2EMTUs(t *testing.T) {
	testPionE2EMTUs(t, serverPion, clientPion)
}