dockerfiles/anylink/dtls-2.0.9/handshaker_test.go

278 lines
7.3 KiB
Go

package dtls
import (
"bytes"
"context"
"crypto/tls"
"sync"
"testing"
"time"
"github.com/pion/dtls/v2/pkg/crypto/selfsign"
"github.com/pion/dtls/v2/pkg/crypto/signaturehash"
"github.com/pion/dtls/v2/pkg/protocol/alert"
"github.com/pion/dtls/v2/pkg/protocol/handshake"
"github.com/pion/dtls/v2/pkg/protocol/recordlayer"
"github.com/pion/logging"
"github.com/pion/transport/test"
)
// nonZeroRetransmitInterval is the retransmitInterval given to both the client
// and the server handshakeConfig in TestHandshaker.
const nonZeroRetransmitInterval = 100 * time.Millisecond
// TestWriteKeyLog checks that writeKeyLog emits a correctly formatted line
// when a key log writer is configured, and then exercises the same call on a
// config that has no writer at all.
func TestWriteKeyLog(t *testing.T) {
	// Each entry is laid out as <Label> <space> <ClientRandom> <space> <Secret>
	// https://developer.mozilla.org/en-US/docs/Mozilla/Projects/NSS/Key_Log_Format
	const expected = "LABEL aabbcc ddeeff\n"

	var sink bytes.Buffer
	withWriter := handshakeConfig{keyLogWriter: &sink}
	withWriter.writeKeyLog("LABEL", []byte{0xAA, 0xBB, 0xCC}, []byte{0xDD, 0xEE, 0xFF})

	if got := sink.String(); got != expected {
		t.Fatalf("Got %s want %s", got, expected)
	}

	// Without a key log writer there is nothing to write to; the call only has
	// to return without blowing up.
	withoutWriter := handshakeConfig{}
	withoutWriter.writeKeyLog("LABEL", []byte{0xAA, 0xBB, 0xCC}, []byte{0xDD, 0xEE, 0xFF})
}
// TestHandshaker drives a client and a server handshake state machine against
// each other over an in-memory flightTestPipe, once per packet-filter scenario.
//
// Each scenario supplies a filter for the client's outgoing packets, a filter
// for the server's outgoing packets, and an optional check that runs after the
// handshake. Both sides must reach handshakeFinished before the shared context
// is cancelled, so context.Canceled is the only result of Run that is not
// reported as a test error.
func TestHandshaker(t *testing.T) {
	// Check for leaking routines
	leakReport := test.CheckRoutines(t)
	defer leakReport()

	log := logging.NewDefaultLoggerFactory().NewLogger("dtls")

	suites, err := parseCipherSuites(nil, nil, true, false)
	if err != nil {
		t.Fatal(err)
	}

	cert, err := selfsign.GenerateSelfSigned()
	if err != nil {
		t.Fatal(err)
	}

	// scenario yields (client filter, server filter, post-handshake check).
	type scenario func() (packetFilter, packetFilter, func(t *testing.T))

	scenarios := map[string]scenario{
		"PassThrough": func() (packetFilter, packetFilter, func(t *testing.T)) {
			return nil, nil, nil
		},
		"HelloVerifyRequestLost": func() (packetFilter, packetFilter, func(t *testing.T)) {
			const dropCount = 5
			var helloVerifySeen, cookielessHelloSeen int

			// Client side: never drops anything, only counts ClientHello
			// messages that carry no cookie.
			clientFilter := func(p *packet) bool {
				if h, ok := p.record.Content.(*handshake.Handshake); ok {
					if hello, ok := h.Message.(*handshake.MessageClientHello); ok && len(hello.Cookie) == 0 {
						cookielessHelloSeen++
					}
				}
				return true
			}

			// Server side: drops the first dropCount HelloVerifyRequest
			// messages and lets everything else through.
			serverFilter := func(p *packet) bool {
				h, ok := p.record.Content.(*handshake.Handshake)
				if !ok {
					return true
				}
				if _, ok := h.Message.(*handshake.MessageHelloVerifyRequest); !ok {
					return true
				}
				helloVerifySeen++
				return helloVerifySeen > dropCount
			}

			verify := func(t *testing.T) {
				if helloVerifySeen != dropCount+1 {
					t.Errorf("Number of HelloVerifyRequest retransmit is wrong, expected: %d times, got: %d times", dropCount+1, helloVerifySeen)
				}
				if cookielessHelloSeen != helloVerifySeen {
					t.Errorf(
						"HelloVerifyRequest must be triggered only by ClientHello, but HelloVerifyRequest was sent %d times and ClientHello was sent %d times",
						helloVerifySeen, cookielessHelloSeen,
					)
				}
			}

			return clientFilter, serverFilter, verify
		},
	}

	for name, gen := range scenarios {
		clientFilter, serverFilter, verify := gen()
		t.Run(name, func(t *testing.T) {
			ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
			defer cancel()
			if verify != nil {
				defer verify(t)
			}

			client, server := flightTestPipe(ctx, clientFilter, serverFilter)
			client.state.isClient = true

			// Each of these contexts is cancelled by its own side once that
			// side reports handshakeFinished.
			clientDone, markClientDone := context.WithCancel(ctx)
			serverDone, markServerDone := context.WithCancel(ctx)

			// newCfg builds the handshake configuration shared by both sides;
			// only the completion callback differs.
			newCfg := func(onFinished context.CancelFunc) *handshakeConfig {
				return &handshakeConfig{
					localCipherSuites:     suites,
					localCertificates:     []tls.Certificate{cert},
					localSignatureSchemes: signaturehash.Algorithms(),
					insecureSkipVerify:    true,
					log:                   log,
					onFlightState: func(_ flightVal, s handshakeState) {
						if s == handshakeFinished {
							onFinished()
						}
					},
					retransmitInterval: nonZeroRetransmitInterval,
				}
			}

			// reportRunErr turns the result of fsm.Run into test failures.
			// Cancellation is the expected way for Run to stop.
			reportRunErr := func(err error) {
				if err == context.Canceled {
					return
				}
				if err == context.DeadlineExceeded {
					t.Error("Timeout")
					return
				}
				t.Error(err)
			}

			var wg sync.WaitGroup
			wg.Add(2)

			// Client starts at flight1.
			go func() {
				defer wg.Done()
				fsm := newHandshakeFSM(&client.state, client.handshakeCache, newCfg(markClientDone), flight1)
				reportRunErr(fsm.Run(ctx, client, handshakePreparing))
			}()

			// Server starts at flight0.
			go func() {
				defer wg.Done()
				fsm := newHandshakeFSM(&server.state, server.handshakeCache, newCfg(markServerDone), flight0)
				reportRunErr(fsm.Run(ctx, server, handshakePreparing))
			}()

			// Wait for both sides to finish (or for the timeout, which also
			// closes these derived contexts), then stop the state machines.
			<-clientDone.Done()
			<-serverDone.Done()

			cancel()
			wg.Wait()
		})
	}
}
// packetFilter inspects an outgoing packet in flightTestConn.writePackets.
// Returning false makes writePackets skip the packet; returning true lets it
// through.
type packetFilter func(*packet) bool
// flightTestPipe creates two flightTestConn endpoints that are wired back to
// back: each endpoint's otherEndCache and otherEndRecv refer to the peer's
// handshakeCache and recv channel.
//
// filter1 is installed on the first returned endpoint and filter2 on the
// second; either may be nil. Both endpoints use ctx.Done() as their done
// channel.
func flightTestPipe(ctx context.Context, filter1 packetFilter, filter2 packetFilter) (*flightTestConn, *flightTestConn) {
	first := &flightTestConn{
		handshakeCache: newHandshakeCache(),
		recv:           make(chan chan struct{}),
		done:           ctx.Done(),
		filter:         filter1,
	}
	second := &flightTestConn{
		handshakeCache: newHandshakeCache(),
		recv:           make(chan chan struct{}),
		done:           ctx.Done(),
		filter:         filter2,
	}

	// Cross-link the two ends so that a write on one side lands in the other
	// side's cache and wakes the other side's receive channel.
	first.otherEndCache, first.otherEndRecv = second.handshakeCache, second.recv
	second.otherEndCache, second.otherEndRecv = first.handshakeCache, first.recv

	return first, second
}
// flightTestConn is an in-memory test endpoint that is passed to the handshake
// FSM's Run in TestHandshaker. Two of them are linked together by
// flightTestPipe.
type flightTestConn struct {
// state is handed to newHandshakeFSM by the test; writePackets reads
// state.isClient when pushing into the caches.
state State
// handshakeCache is this endpoint's own cache; writePackets records every
// handshake message it sends here.
handshakeCache *handshakeCache
// recv is the channel returned by recvHandshake.
recv chan chan struct{}
// done is set from ctx.Done() in flightTestPipe; it releases the
// notification goroutine started by writePackets.
done <-chan struct{}
// epoch holds the last value passed to setLocalEpoch.
epoch uint16
// filter is optional; when set, writePackets skips packets for which it
// returns false.
filter packetFilter
// otherEndCache is the peer's handshakeCache; writePackets pushes the
// re-marshalled handshake message into it.
otherEndCache *handshakeCache
// otherEndRecv is the peer's recv channel; writePackets sends on it after
// processing a batch of packets.
otherEndRecv chan chan struct{}
}
// recvHandshake returns the endpoint's recv channel, on which the peer's
// writePackets sends a notification after each batch of packets.
func (c *flightTestConn) recvHandshake() <-chan chan struct{} {
return c.recv
}
// setLocalEpoch stores epoch on the test connection.
func (c *flightTestConn) setLocalEpoch(epoch uint16) {
c.epoch = epoch
}
// notify is a no-op on the test connection: the alert is ignored and nil is
// always returned.
func (c *flightTestConn) notify(ctx context.Context, level alert.Level, desc alert.Description) error {
return nil
}
// writePackets simulates sending pkts to the peer endpoint.
//
// Packets rejected by the optional filter, and packets whose record content is
// not a handshake message, are skipped. For every remaining packet the
// marshalled record (minus the record-layer header) is pushed into this
// endpoint's own cache, then the handshake header is rewritten so that Length
// and FragmentLength both equal the full message length, and header plus
// message are pushed into the peer's cache.
//
// After the batch, a notification is sent to the peer's receive channel from a
// separate goroutine, which gives up once c.done is closed. The ctx argument is
// not used. A marshalling error is returned as-is and aborts the batch.
func (c *flightTestConn) writePackets(ctx context.Context, pkts []*packet) error {
	for _, pkt := range pkts {
		dropped := c.filter != nil && !c.filter(pkt)
		if dropped {
			continue
		}

		hs, isHandshake := pkt.record.Content.(*handshake.Handshake)
		if !isHandshake {
			continue
		}

		// Sender side: cache the message exactly as it was marshalled into the
		// record. This has to happen before the header is rewritten below.
		raw, err := pkt.record.Marshal()
		if err != nil {
			return err
		}
		c.handshakeCache.push(raw[recordlayer.HeaderSize:], pkt.record.Header.Epoch, hs.Header.MessageSequence, hs.Header.Type, c.state.isClient)

		// Receiver side: cache the message as a single, complete fragment.
		body, err := hs.Message.Marshal()
		if err != nil {
			return err
		}
		bodyLen := uint32(len(body))
		hs.Header.Length = bodyLen
		hs.Header.FragmentLength = bodyLen

		header, err := hs.Header.Marshal()
		if err != nil {
			return err
		}
		c.otherEndCache.push(
			append(header, body...), pkt.record.Header.Epoch, hs.Header.MessageSequence, hs.Header.Type, c.state.isClient)
	}

	// Wake the peer without blocking this call; the goroutine exits when the
	// peer receives the notification or when c.done is closed.
	go func() {
		select {
		case c.otherEndRecv <- make(chan struct{}):
		case <-c.done:
		}
	}()

	// Avoid deadlock on JS/WASM environment due to context switch problem.
	time.Sleep(10 * time.Millisecond)

	return nil
}
// handleQueuedPackets is a no-op on the test connection and always returns
// nil.
func (c *flightTestConn) handleQueuedPackets(ctx context.Context) error {
return nil
}