// Source: dockerfiles/anylink/dtls-2.0.9/conn_test.go (2027 lines, 55 KiB, Go)

package dtls
import (
"bytes"
"context"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"io"
"net"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/pion/dtls/v2/internal/ciphersuite"
"github.com/pion/dtls/v2/internal/net/dpipe"
"github.com/pion/dtls/v2/pkg/crypto/elliptic"
"github.com/pion/dtls/v2/pkg/crypto/hash"
"github.com/pion/dtls/v2/pkg/crypto/selfsign"
"github.com/pion/dtls/v2/pkg/crypto/signature"
"github.com/pion/dtls/v2/pkg/crypto/signaturehash"
"github.com/pion/dtls/v2/pkg/protocol"
"github.com/pion/dtls/v2/pkg/protocol/alert"
"github.com/pion/dtls/v2/pkg/protocol/extension"
"github.com/pion/dtls/v2/pkg/protocol/handshake"
"github.com/pion/dtls/v2/pkg/protocol/recordlayer"
"github.com/pion/transport/test"
)
// Sentinel errors used by the callbacks that the tests in this file install.
var (
// errTestPSKInvalidIdentity is wrapped by the server-side PSK callback in
// TestPSK when the received hint differs from the expected client identity.
errTestPSKInvalidIdentity = errors.New("TestPSK: Server got invalid identity")
// errPSKRejected is returned by the PSK callbacks in TestPSKHintFail.
errPSKRejected = errors.New("PSK Rejected")
// errNotExpectedChain is returned by a VerifyPeerCertificate callback that
// was handed a non-empty chain it did not expect.
errNotExpectedChain = errors.New("not expected chain")
// errExpecedChain (sic) is returned by a VerifyPeerCertificate callback that
// was handed an empty chain although one was expected.
errExpecedChain = errors.New("expected chain")
// errWrongCert is returned by a VerifyPeerCertificate callback that rejects
// every certificate.
errWrongCert = errors.New("wrong cert")
)
// TestStressDuplex runs the duplex stress scenario under a 20 second
// watchdog and with goroutine-leak checking enabled.
func TestStressDuplex(t *testing.T) {
	// Abort the run if it hangs, e.g. because of a deadlock.
	watchdog := test.TimeOut(20 * time.Second)
	defer watchdog.Stop()

	// Report goroutines that are still alive once the body has finished.
	checkLeaks := test.CheckRoutines(t)
	defer checkLeaks()

	stressDuplex(t)
}
// stressDuplex builds an in-memory DTLS connection pair and hands it to
// test.StressDuplex with 100 messages of 2048 bytes. Any setup, transfer or
// close error fails the test immediately.
func stressDuplex(t *testing.T) {
	client, server, err := pipeMemory()
	if err != nil {
		t.Fatal(err)
	}

	// Tear both ends down on exit; a failing Close is itself a test failure.
	defer func() {
		if closeErr := client.Close(); closeErr != nil {
			t.Fatal(closeErr)
		}
		if closeErr := server.Close(); closeErr != nil {
			t.Fatal(closeErr)
		}
	}()

	options := test.Options{MsgSize: 2048, MsgCount: 100}
	if err = test.StressDuplex(client, server, options); err != nil {
		t.Fatal(err)
	}
}
// TestRoutineLeakOnClose writes a packet that the peer never reads and then
// closes both ends. The leak checker verifies that no goroutine (notably the
// inbound loop) survives the close.
func TestRoutineLeakOnClose(t *testing.T) {
	// Abort the run if it hangs, e.g. because of a deadlock.
	watchdog := test.TimeOut(5 * time.Second)
	defer watchdog.Stop()

	// Report goroutines that are still alive once the body has finished.
	checkLeaks := test.CheckRoutines(t)
	defer checkLeaks()

	sender, receiver, err := pipeMemory()
	if err != nil {
		t.Fatal(err)
	}

	// The packet is sent but deliberately left unread by the receiver.
	payload := make([]byte, 100)
	if _, err = sender.Write(payload); err != nil {
		t.Fatal(err)
	}

	if err = receiver.Close(); err != nil {
		t.Fatal(err)
	}
	if err = sender.Close(); err != nil {
		t.Fatal(err)
	}
}
// TestReadWriteDeadline checks that Read and Write on a connection whose
// deadline has already passed return a net.Error flagged as both Timeout and
// Temporary, and that after Close a Write yields ErrConnClosed while a Read
// yields io.EOF.
func TestReadWriteDeadline(t *testing.T) {
	// Abort the run if it hangs, e.g. because of a deadlock.
	watchdog := test.TimeOut(5 * time.Second)
	defer watchdog.Stop()

	// Report goroutines that are still alive once the body has finished.
	checkLeaks := test.CheckRoutines(t)
	defer checkLeaks()

	ca, cb, err := pipeMemory()
	if err != nil {
		t.Fatal(err)
	}

	// A deadline one nanosecond after the Unix epoch is always in the past.
	expired := time.Unix(0, 1)
	if err = ca.SetDeadline(expired); err != nil {
		t.Fatal(err)
	}

	_, writeErr := ca.Write(make([]byte, 100))
	writeNetErr, writeIsNetErr := writeErr.(net.Error)
	switch {
	case !writeIsNetErr:
		t.Error("Write must return net.Error error")
	default:
		if !writeNetErr.Timeout() {
			t.Error("Deadline exceeded Write must return Timeout error")
		}
		if !writeNetErr.Temporary() {
			t.Error("Deadline exceeded Write must return Temporary error")
		}
	}

	_, readErr := ca.Read(make([]byte, 100))
	readNetErr, readIsNetErr := readErr.(net.Error)
	switch {
	case !readIsNetErr:
		t.Error("Read must return net.Error error")
	default:
		if !readNetErr.Timeout() {
			t.Error("Deadline exceeded Read must return Timeout error")
		}
		if !readNetErr.Temporary() {
			t.Error("Deadline exceeded Read must return Temporary error")
		}
	}

	// Clear the deadline again before shutting both ends down.
	if err = ca.SetDeadline(time.Time{}); err != nil {
		t.Error(err)
	}
	if err = ca.Close(); err != nil {
		t.Error(err)
	}
	if err = cb.Close(); err != nil {
		t.Error(err)
	}

	// Operations on the closed connection must fail in a well-defined way.
	_, err = ca.Write(make([]byte, 100))
	if !errors.Is(err, ErrConnClosed) {
		t.Errorf("Write must return %v after close, got %v", ErrConnClosed, err)
	}
	_, err = ca.Read(make([]byte, 100))
	if !errors.Is(err, io.EOF) {
		t.Errorf("Read must return %v after close, got %v", io.EOF, err)
	}
}
// TestSequenceNumberOverflow forces the local record sequence number to its
// limit and checks that further writes are refused with
// errSequenceNumberOverflow, for application data (index 1 of
// localSequenceNumber) and for handshake records (index 0).
func TestSequenceNumberOverflow(t *testing.T) {
	// Abort the run if it hangs, e.g. because of a deadlock.
	watchdog := test.TimeOut(5 * time.Second)
	defer watchdog.Stop()

	// Report goroutines that are still alive once the body has finished.
	checkLeaks := test.CheckRoutines(t)
	defer checkLeaks()

	t.Run("ApplicationData", func(t *testing.T) {
		ca, cb, err := pipeMemory()
		if err != nil {
			t.Fatal(err)
		}

		// The next record uses the last valid sequence number.
		atomic.StoreUint64(&ca.state.localSequenceNumber[1], recordlayer.MaxSequenceNumber)

		_, firstErr := ca.Write(make([]byte, 100))
		if firstErr != nil {
			t.Errorf("Write must send message with maximum sequence number, but errord: %v", firstErr)
		}

		_, secondErr := ca.Write(make([]byte, 100))
		if !errors.Is(secondErr, errSequenceNumberOverflow) {
			t.Errorf("Write must abandonsend message with maximum sequence number, but errord: %v", secondErr)
		}

		if err = ca.Close(); err != nil {
			t.Error(err)
		}
		if err = cb.Close(); err != nil {
			t.Error(err)
		}
	})

	t.Run("Handshake", func(t *testing.T) {
		ca, cb, err := pipeMemory()
		if err != nil {
			t.Fatal(err)
		}

		ctx, cancel := context.WithTimeout(context.Background(), time.Second)
		defer cancel()

		// Push the handshake sequence number one past the maximum.
		atomic.StoreUint64(&ca.state.localSequenceNumber[0], recordlayer.MaxSequenceNumber+1)

		// A ClientHello record that the connection is asked to send.
		clientHello := &packet{
			record: &recordlayer.RecordLayer{
				Header: recordlayer.Header{
					Version: protocol.Version1_2,
				},
				Content: &handshake.Handshake{
					Message: &handshake.MessageClientHello{
						Version:            protocol.Version1_2,
						Cookie:             make([]byte, 64),
						CipherSuiteIDs:     cipherSuiteIDs(defaultCipherSuites()),
						CompressionMethods: defaultCompressionMethods(),
					},
				},
			},
		}

		sendErr := ca.writePackets(ctx, []*packet{clientHello})
		if !errors.Is(sendErr, errSequenceNumberOverflow) {
			t.Errorf("Connection must fail on handshake packet reaches maximum sequence number")
		}

		if err = ca.Close(); err != nil {
			t.Error(err)
		}
		if err = cb.Close(); err != nil {
			t.Error(err)
		}
	})
}
// pipeMemory returns a connected client/server *Conn pair whose transport is
// an in-memory dpipe, as produced by pipeConn.
func pipeMemory() (*Conn, *Conn, error) {
	// Both ends of the in-memory pipe are forwarded straight to pipeConn.
	return pipeConn(dpipe.Pipe())
}
// pipeConn performs a DTLS handshake over the two given transports, running
// the client on ca (in a goroutine) and the server on cb.
//
// Both sides use a freshly generated self-signed certificate and offer the
// SRTP_AES128_CM_HMAC_SHA1_80 protection profile. The handshake is bounded by
// a 10 second context.
//
// It returns (client, server, nil) on success. On failure both connections
// are nil, the side that did complete its handshake (if any) is closed, and
// the client goroutine has finished before pipeConn returns.
func pipeConn(ca, cb net.Conn) (*Conn, *Conn, error) {
	type result struct {
		c   *Conn
		err error
	}

	// Buffered so the client goroutine can always deliver its result and exit.
	clientCh := make(chan result, 1)

	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()

	// Setup client
	go func() {
		client, err := testClient(ctx, ca, &Config{SRTPProtectionProfiles: []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80}}, true)
		clientCh <- result{client, err}
	}()

	// Setup server
	server, err := testServer(ctx, cb, &Config{SRTPProtectionProfiles: []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80}}, true)
	if err != nil {
		// Abort the client handshake and reap its goroutine so that nothing
		// is left running (or open) behind the caller's back.
		cancel()
		if res := <-clientCh; res.err == nil {
			_ = res.c.Close()
		}
		return nil, nil, err
	}

	// Receive client
	res := <-clientCh
	if res.err != nil {
		// The server handshake succeeded; do not leak its connection.
		_ = server.Close()
		return nil, nil, res.err
	}

	return res.c, server, nil
}
// testClient starts a DTLS client handshake on c using cfg.
//
// When generateCertificate is true a new self-signed certificate replaces
// cfg.Certificates. cfg.InsecureSkipVerify is always forced to true. Note
// that cfg is modified in place.
func testClient(ctx context.Context, c net.Conn, cfg *Config, generateCertificate bool) (*Conn, error) {
	if generateCertificate {
		selfSigned, genErr := selfsign.GenerateSelfSigned()
		if genErr != nil {
			return nil, genErr
		}
		cfg.Certificates = []tls.Certificate{selfSigned}
	}

	// Test certificates are self-signed, so peer verification is skipped.
	cfg.InsecureSkipVerify = true

	return ClientWithContext(ctx, c, cfg)
}
// testServer starts a DTLS server handshake on c using cfg.
//
// When generateCertificate is true a new self-signed certificate replaces
// cfg.Certificates. Note that cfg is modified in place in that case.
func testServer(ctx context.Context, c net.Conn, cfg *Config, generateCertificate bool) (*Conn, error) {
	if generateCertificate {
		selfSigned, genErr := selfsign.GenerateSelfSigned()
		if genErr != nil {
			return nil, genErr
		}
		cfg.Certificates = []tls.Certificate{selfSigned}
	}

	return ServerWithContext(ctx, c, cfg)
}
// TestHandshakeWithAlert runs handshakes between deliberately incompatible
// client and server configurations and checks the error each side reports:
// one side fails with a local error, the other with the fatal alert it
// received.
func TestHandshakeWithAlert(t *testing.T) {
	// Abort the run if it hangs, e.g. because of a deadlock.
	watchdog := test.TimeOut(20 * time.Second)
	defer watchdog.Stop()

	// Report goroutines that are still alive once the body has finished.
	checkLeaks := test.CheckRoutines(t)
	defer checkLeaks()

	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()

	type scenario struct {
		configServer, configClient *Config
		errServer, errClient       error
	}

	scenarios := map[string]scenario{
		"CipherSuiteNoIntersection": {
			configServer: &Config{
				CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
			},
			configClient: &Config{
				CipherSuites: []CipherSuiteID{TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256},
			},
			errServer: errCipherSuiteNoIntersection,
			errClient: &errAlert{&alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}},
		},
		"SignatureSchemesNoIntersection": {
			configServer: &Config{
				CipherSuites:     []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
				SignatureSchemes: []tls.SignatureScheme{tls.ECDSAWithP256AndSHA256},
			},
			configClient: &Config{
				CipherSuites:     []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
				SignatureSchemes: []tls.SignatureScheme{tls.ECDSAWithP521AndSHA512},
			},
			errServer: &errAlert{&alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}},
			errClient: errNoAvailableSignatureSchemes,
		},
	}

	for name, sc := range scenarios {
		sc := sc
		t.Run(name, func(t *testing.T) {
			// Buffered so the client goroutine never blocks on delivery.
			clientDone := make(chan error, 1)
			clientSide, serverSide := dpipe.Pipe()

			go func() {
				_, handshakeErr := testClient(ctx, clientSide, sc.configClient, true)
				clientDone <- handshakeErr
			}()

			_, gotServer := testServer(ctx, serverSide, sc.configServer, true)
			if !errors.Is(gotServer, sc.errServer) {
				t.Fatalf("Server error exp(%v) failed(%v)", sc.errServer, gotServer)
			}

			gotClient := <-clientDone
			if !errors.Is(gotClient, sc.errClient) {
				t.Fatalf("Client error exp(%v) failed(%v)", sc.errClient, gotClient)
			}
		})
	}
}
// TestExportKeyingMaterial exercises State.ExportKeyingMaterial on a
// hand-built connection state:
//   - it must fail with errHandshakeInProgress while the local epoch is 0,
//   - it must reject a non-nil context with errContextUnsupported,
//   - it must reject every label listed by invalidKeyingLabels,
//   - it must produce the expected 10 bytes for the server role and, after
//     flipping isClient, for the client role.
func TestExportKeyingMaterial(t *testing.T) {
	// Check for leaking routines
	report := test.CheckRoutines(t)
	defer report()

	// All-zero random bytes keep the exported material deterministic. The
	// name avoids shadowing the imported crypto/rand package.
	var zeroRandom [28]byte
	exportLabel := "EXTRACTOR-dtls_srtp"

	expectedServerKey := []byte{0x61, 0x09, 0x9d, 0x7d, 0xcb, 0x08, 0x52, 0x2c, 0xe7, 0x7b}
	expectedClientKey := []byte{0x87, 0xf0, 0x40, 0x02, 0xf6, 0x1c, 0xf1, 0xfe, 0x8c, 0x77}

	c := &Conn{
		state: State{
			localRandom:         handshake.Random{GMTUnixTime: time.Unix(500, 0), RandomBytes: zeroRandom},
			remoteRandom:        handshake.Random{GMTUnixTime: time.Unix(1000, 0), RandomBytes: zeroRandom},
			localSequenceNumber: []uint64{0, 0},
			cipherSuite:         &ciphersuite.TLSEcdheEcdsaWithAes128GcmSha256{},
		},
	}
	c.setLocalEpoch(0)
	c.setRemoteEpoch(0)

	// Epoch 0: the handshake has not finished, exporting must be refused.
	state := c.ConnectionState()
	_, err := state.ExportKeyingMaterial(exportLabel, nil, 0)
	if !errors.Is(err, errHandshakeInProgress) {
		t.Errorf("ExportKeyingMaterial when epoch == 0: expected '%v' actual '%v'", errHandshakeInProgress, err)
	}

	c.setLocalEpoch(1)

	// A context value is not supported.
	state = c.ConnectionState()
	_, err = state.ExportKeyingMaterial(exportLabel, []byte{0x00}, 0)
	if !errors.Is(err, errContextUnsupported) {
		t.Errorf("ExportKeyingMaterial with context: expected '%v' actual '%v'", errContextUnsupported, err)
	}

	// Reserved labels must be rejected.
	for k := range invalidKeyingLabels() {
		state = c.ConnectionState()
		_, err = state.ExportKeyingMaterial(k, nil, 0)
		if !errors.Is(err, errReservedExportKeyingMaterial) {
			t.Errorf("ExportKeyingMaterial reserved label: expected '%v' actual '%v'", errReservedExportKeyingMaterial, err)
		}
	}

	// Server role (isClient is still false).
	state = c.ConnectionState()
	keyingMaterial, err := state.ExportKeyingMaterial(exportLabel, nil, 10)
	if err != nil {
		t.Errorf("ExportKeyingMaterial as server: unexpected error '%v'", err)
	} else if !bytes.Equal(keyingMaterial, expectedServerKey) {
		t.Errorf("ExportKeyingMaterial server export: expected (% 02x) actual (% 02x)", expectedServerKey, keyingMaterial)
	}

	// Client role.
	c.state.isClient = true
	state = c.ConnectionState()
	keyingMaterial, err = state.ExportKeyingMaterial(exportLabel, nil, 10)
	if err != nil {
		t.Errorf("ExportKeyingMaterial as client: unexpected error '%v'", err)
	} else if !bytes.Equal(keyingMaterial, expectedClientKey) {
		t.Errorf("ExportKeyingMaterial client export: expected (% 02x) actual (% 02x)", expectedClientKey, keyingMaterial)
	}
}
// TestPSK performs PSK handshakes for several cipher suites, with and without
// a server identity hint. Each side's PSK callback validates the hint it is
// given, and the server's connection state must expose the client identity.
func TestPSK(t *testing.T) {
	// Abort the run if it hangs, e.g. because of a deadlock.
	watchdog := test.TimeOut(20 * time.Second)
	defer watchdog.Stop()

	// Report goroutines that are still alive once the body has finished.
	checkLeaks := test.CheckRoutines(t)
	defer checkLeaks()

	type pskCase struct {
		Name           string
		ServerIdentity []byte
		CipherSuites   []CipherSuiteID
	}

	cases := []pskCase{
		{
			Name:           "Server identity specified",
			ServerIdentity: []byte("Test Identity"),
			CipherSuites:   []CipherSuiteID{TLS_PSK_WITH_AES_128_CCM_8},
		},
		{
			Name:           "Server identity nil",
			ServerIdentity: nil,
			CipherSuites:   []CipherSuiteID{TLS_PSK_WITH_AES_128_CCM_8},
		},
		{
			Name:           "TLS_PSK_WITH_AES_128_CBC_SHA256",
			ServerIdentity: nil,
			CipherSuites:   []CipherSuiteID{TLS_PSK_WITH_AES_128_CBC_SHA256},
		},
	}

	for _, tc := range cases {
		tc := tc
		t.Run(tc.Name, func(t *testing.T) {
			ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
			defer cancel()

			clientIdentity := []byte("Client Identity")

			type result struct {
				c   *Conn
				err error
			}
			// Buffered so the client goroutine never blocks on delivery.
			clientRes := make(chan result, 1)

			ca, cb := dpipe.Pipe()

			go func() {
				clientCfg := &Config{
					// The client is handed the server's identity hint.
					PSK: func(hint []byte) ([]byte, error) {
						if !bytes.Equal(tc.ServerIdentity, hint) { // nolint
							return nil, fmt.Errorf("TestPSK: Client got invalid identity expected(% 02x) actual(% 02x)", tc.ServerIdentity, hint) // nolint
						}
						return []byte{0xAB, 0xC1, 0x23}, nil
					},
					PSKIdentityHint: clientIdentity,
					CipherSuites:    tc.CipherSuites,
				}

				conn, handshakeErr := testClient(ctx, ca, clientCfg, false)
				clientRes <- result{conn, handshakeErr}
			}()

			serverCfg := &Config{
				// The server is handed the client's identity.
				PSK: func(hint []byte) ([]byte, error) {
					if !bytes.Equal(clientIdentity, hint) {
						return nil, fmt.Errorf("%w: expected(% 02x) actual(% 02x)", errTestPSKInvalidIdentity, clientIdentity, hint)
					}
					return []byte{0xAB, 0xC1, 0x23}, nil
				},
				PSKIdentityHint: tc.ServerIdentity,
				CipherSuites:    tc.CipherSuites,
			}

			server, err := testServer(ctx, cb, serverCfg, false)
			if err != nil {
				t.Fatalf("TestPSK: Server failed(%v)", err)
			}

			gotHint := server.ConnectionState().IdentityHint
			if !bytes.Equal(gotHint, clientIdentity) {
				t.Errorf("TestPSK: Server ClientPSKIdentity Mismatch '%s': expected(%v) actual(%v)", tc.Name, clientIdentity, gotHint)
			}

			defer func() {
				_ = server.Close()
			}()

			clientOutcome := <-clientRes
			if clientOutcome.err != nil {
				t.Fatal(clientOutcome.err)
			}
			_ = clientOutcome.c.Close()
		})
	}
}
// TestPSKHintFail checks the handshake outcome when the PSK callbacks on both
// sides reject the request: the client must fail with errPSKRejected and the
// server with a fatal InternalError alert.
func TestPSKHintFail(t *testing.T) {
	// Report goroutines that are still alive once the body has finished.
	checkLeaks := test.CheckRoutines(t)
	defer checkLeaks()

	wantServerErr := &errAlert{&alert.Alert{Level: alert.Fatal, Description: alert.InternalError}}
	wantClientErr := errPSKRejected

	// Abort the run if it hangs, e.g. because of a deadlock.
	watchdog := test.TimeOut(20 * time.Second)
	defer watchdog.Stop()

	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()

	// Both peers use the same settings; each gets its own Config value.
	rejectingConfig := func() *Config {
		return &Config{
			PSK: func(hint []byte) ([]byte, error) {
				return nil, wantClientErr
			},
			PSKIdentityHint: []byte{},
			CipherSuites:    []CipherSuiteID{TLS_PSK_WITH_AES_128_CCM_8},
		}
	}

	// Buffered so the client goroutine never blocks on delivery.
	clientDone := make(chan error, 1)
	ca, cb := dpipe.Pipe()

	go func() {
		_, handshakeErr := testClient(ctx, ca, rejectingConfig(), false)
		clientDone <- handshakeErr
	}()

	_, serverErr := testServer(ctx, cb, rejectingConfig(), false)
	if !errors.Is(serverErr, wantServerErr) {
		t.Fatalf("TestPSK: Server error exp(%v) failed(%v)", wantServerErr, serverErr)
	}

	clientErr := <-clientDone
	if !errors.Is(clientErr, wantClientErr) {
		t.Fatalf("TestPSK: Client error exp(%v) failed(%v)", wantClientErr, clientErr)
	}
}
// TestClientTimeout starts a client handshake with nobody answering on the
// other end of the pipe and expects it to end, once the one second context
// expires, with a net.Error whose Timeout method reports true.
func TestClientTimeout(t *testing.T) {
	// Abort the run if it hangs, e.g. because of a deadlock.
	watchdog := test.TimeOut(20 * time.Second)
	defer watchdog.Stop()

	// Report goroutines that are still alive once the body has finished.
	checkLeaks := test.CheckRoutines(t)
	defer checkLeaks()

	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()

	// Buffered so the client goroutine never blocks on delivery.
	clientDone := make(chan error, 1)

	// Only the client end is used: there is no server.
	clientSide, _ := dpipe.Pipe()

	go func() {
		conn, handshakeErr := testClient(ctx, clientSide, &Config{}, true)
		if handshakeErr == nil {
			_ = conn.Close()
		}
		clientDone <- handshakeErr
	}()

	err := <-clientDone
	netErr, isNetErr := err.(net.Error)
	if !isNetErr || !netErr.Timeout() {
		t.Fatalf("Client error exp(Temporary network error) failed(%v)", err)
	}
}
// TestSRTPConfiguration checks SRTP protection profile negotiation for
// several combinations of client and server profile lists, verifying both
// the handshake errors and the profile selected on each side.
//
// Every table entry runs as its own subtest so that an early exit (for
// example when the client handshake is expected to fail) only ends that
// entry, and so that deferred cleanup runs per entry rather than when the
// whole test returns.
func TestSRTPConfiguration(t *testing.T) {
	// Check for leaking routines
	report := test.CheckRoutines(t)
	defer report()

	for _, test := range []struct {
		Name            string
		ClientSRTP      []SRTPProtectionProfile
		ServerSRTP      []SRTPProtectionProfile
		ExpectedProfile SRTPProtectionProfile
		WantClientError error
		WantServerError error
	}{
		{
			Name:            "No SRTP in use",
			ClientSRTP:      nil,
			ServerSRTP:      nil,
			ExpectedProfile: 0,
			WantClientError: nil,
			WantServerError: nil,
		},
		{
			Name:            "SRTP both ends",
			ClientSRTP:      []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80},
			ServerSRTP:      []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80},
			ExpectedProfile: SRTP_AES128_CM_HMAC_SHA1_80,
			WantClientError: nil,
			WantServerError: nil,
		},
		{
			Name:            "SRTP client only",
			ClientSRTP:      []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80},
			ServerSRTP:      nil,
			ExpectedProfile: 0,
			WantClientError: &errAlert{&alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}},
			WantServerError: errServerNoMatchingSRTPProfile,
		},
		{
			Name:            "SRTP server only",
			ClientSRTP:      nil,
			ServerSRTP:      []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80},
			ExpectedProfile: 0,
			WantClientError: nil,
			WantServerError: nil,
		},
		{
			Name:            "Multiple Suites",
			ClientSRTP:      []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80, SRTP_AES128_CM_HMAC_SHA1_32},
			ServerSRTP:      []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80, SRTP_AES128_CM_HMAC_SHA1_32},
			ExpectedProfile: SRTP_AES128_CM_HMAC_SHA1_80,
			WantClientError: nil,
			WantServerError: nil,
		},
		{
			Name:            "Multiple Suites, Client Chooses",
			ClientSRTP:      []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80, SRTP_AES128_CM_HMAC_SHA1_32},
			ServerSRTP:      []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_32, SRTP_AES128_CM_HMAC_SHA1_80},
			ExpectedProfile: SRTP_AES128_CM_HMAC_SHA1_80,
			WantClientError: nil,
			WantServerError: nil,
		},
	} {
		test := test
		t.Run(test.Name, func(t *testing.T) {
			ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
			defer cancel()

			ca, cb := dpipe.Pipe()
			type result struct {
				c   *Conn
				err error
			}
			c := make(chan result)

			go func() {
				client, err := testClient(ctx, ca, &Config{SRTPProtectionProfiles: test.ClientSRTP}, true)
				c <- result{client, err}
			}()

			server, err := testServer(ctx, cb, &Config{SRTPProtectionProfiles: test.ServerSRTP}, true)
			if !errors.Is(err, test.WantServerError) {
				t.Errorf("TestSRTPConfiguration: Server Error Mismatch '%s': expected(%v) actual(%v)", test.Name, test.WantServerError, err)
			}
			if err == nil {
				defer func() {
					_ = server.Close()
				}()
			}

			res := <-c
			if res.err == nil {
				defer func() {
					_ = res.c.Close()
				}()
			}
			if !errors.Is(res.err, test.WantClientError) {
				t.Fatalf("TestSRTPConfiguration: Client Error Mismatch '%s': expected(%v) actual(%v)", test.Name, test.WantClientError, res.err)
			}

			// Without a client connection there is no profile to inspect;
			// this ends only the current subtest.
			if res.c == nil {
				return
			}

			actualClientSRTP, _ := res.c.SelectedSRTPProtectionProfile()
			if actualClientSRTP != test.ExpectedProfile {
				t.Errorf("TestSRTPConfiguration: Client SRTPProtectionProfile Mismatch '%s': expected(%v) actual(%v)", test.Name, test.ExpectedProfile, actualClientSRTP)
			}

			actualServerSRTP, _ := server.SelectedSRTPProtectionProfile()
			if actualServerSRTP != test.ExpectedProfile {
				t.Errorf("TestSRTPConfiguration: Server SRTPProtectionProfile Mismatch '%s': expected(%v) actual(%v)", test.Name, test.ExpectedProfile, actualServerSRTP)
			}
		})
	}
}
// TestClientCertificate runs handshakes for the different server ClientAuth
// policies, with and without a client certificate, and checks both the
// handshake outcome and the peer certificates each side ends up with.
//
// The parallel subtests are grouped under a single "parallel" subtest so that
// they have all finished before the goroutine-leak report runs.
func TestClientCertificate(t *testing.T) {
	// Check for leaking routines
	report := test.CheckRoutines(t)
	defer report()

	// Server certificate and the pool the client uses to verify it.
	srvCert, err := selfsign.GenerateSelfSigned()
	if err != nil {
		t.Fatal(err)
	}
	srvCAPool := x509.NewCertPool()
	srvCertificate, err := x509.ParseCertificate(srvCert.Certificate[0])
	if err != nil {
		t.Fatal(err)
	}
	srvCAPool.AddCert(srvCertificate)

	// Client certificate and the pool the server uses to verify it.
	cert, err := selfsign.GenerateSelfSigned()
	if err != nil {
		t.Fatal(err)
	}
	certificate, err := x509.ParseCertificate(cert.Certificate[0])
	if err != nil {
		t.Fatal(err)
	}
	caPool := x509.NewCertPool()
	caPool.AddCert(certificate)

	t.Run("parallel", func(t *testing.T) { // sync routines to check routine leak
		tests := map[string]struct {
			clientCfg *Config
			serverCfg *Config
			wantErr   bool
		}{
			"NoClientCert": {
				clientCfg: &Config{RootCAs: srvCAPool},
				serverCfg: &Config{
					Certificates: []tls.Certificate{srvCert},
					ClientAuth:   NoClientCert,
					ClientCAs:    caPool,
				},
			},
			"NoClientCert_cert": {
				clientCfg: &Config{RootCAs: srvCAPool, Certificates: []tls.Certificate{cert}},
				serverCfg: &Config{
					Certificates: []tls.Certificate{srvCert},
					ClientAuth:   RequireAnyClientCert,
				},
			},
			"RequestClientCert_cert": {
				clientCfg: &Config{RootCAs: srvCAPool, Certificates: []tls.Certificate{cert}},
				serverCfg: &Config{
					Certificates: []tls.Certificate{srvCert},
					ClientAuth:   RequestClientCert,
				},
			},
			"RequestClientCert_no_cert": {
				clientCfg: &Config{RootCAs: srvCAPool},
				serverCfg: &Config{
					Certificates: []tls.Certificate{srvCert},
					ClientAuth:   RequestClientCert,
					ClientCAs:    caPool,
				},
			},
			"RequireAnyClientCert": {
				clientCfg: &Config{RootCAs: srvCAPool, Certificates: []tls.Certificate{cert}},
				serverCfg: &Config{
					Certificates: []tls.Certificate{srvCert},
					ClientAuth:   RequireAnyClientCert,
				},
			},
			"RequireAnyClientCert_error": {
				clientCfg: &Config{RootCAs: srvCAPool},
				serverCfg: &Config{
					Certificates: []tls.Certificate{srvCert},
					ClientAuth:   RequireAnyClientCert,
				},
				wantErr: true,
			},
			"VerifyClientCertIfGiven_no_cert": {
				clientCfg: &Config{RootCAs: srvCAPool},
				serverCfg: &Config{
					Certificates: []tls.Certificate{srvCert},
					ClientAuth:   VerifyClientCertIfGiven,
					ClientCAs:    caPool,
				},
			},
			"VerifyClientCertIfGiven_cert": {
				clientCfg: &Config{RootCAs: srvCAPool, Certificates: []tls.Certificate{cert}},
				serverCfg: &Config{
					Certificates: []tls.Certificate{srvCert},
					ClientAuth:   VerifyClientCertIfGiven,
					ClientCAs:    caPool,
				},
			},
			"VerifyClientCertIfGiven_error": {
				clientCfg: &Config{RootCAs: srvCAPool, Certificates: []tls.Certificate{cert}},
				serverCfg: &Config{
					Certificates: []tls.Certificate{srvCert},
					ClientAuth:   VerifyClientCertIfGiven,
				},
				wantErr: true,
			},
			"RequireAndVerifyClientCert": {
				clientCfg: &Config{RootCAs: srvCAPool, Certificates: []tls.Certificate{cert}},
				serverCfg: &Config{
					Certificates: []tls.Certificate{srvCert},
					ClientAuth:   RequireAndVerifyClientCert,
					ClientCAs:    caPool,
				},
			},
		}
		for name, tt := range tests {
			tt := tt
			t.Run(name, func(t *testing.T) {
				t.Parallel()

				ca, cb := dpipe.Pipe()
				type result struct {
					c   *Conn
					err error
				}
				c := make(chan result)

				go func() {
					client, err := Client(ca, tt.clientCfg)
					c <- result{client, err}
				}()

				server, err := Server(cb, tt.serverCfg)
				res := <-c
				defer func() {
					if err == nil {
						_ = server.Close()
					}
					if res.err == nil {
						_ = res.c.Close()
					}
				}()

				if tt.wantErr {
					if err == nil {
						t.Error("Error expected")
					}
					// Nothing further to verify for a handshake that was
					// supposed to fail; the certificate checks below assume
					// a successful, fully configured handshake.
					return
				}

				if err != nil {
					t.Errorf("Server failed(%v)", err)
				}
				if res.err != nil {
					t.Errorf("Client failed(%v)", res.err)
				}
				if err != nil || res.err != nil {
					// A failed side has no connection to inspect; stop here
					// instead of dereferencing a nil *Conn.
					return
				}

				actualClientCert := server.ConnectionState().PeerCertificates
				if tt.serverCfg.ClientAuth == RequireAnyClientCert || tt.serverCfg.ClientAuth == RequireAndVerifyClientCert {
					if actualClientCert == nil {
						t.Errorf("Client did not provide a certificate")
					}

					if len(actualClientCert) != len(tt.clientCfg.Certificates[0].Certificate) || !bytes.Equal(tt.clientCfg.Certificates[0].Certificate[0], actualClientCert[0]) {
						t.Errorf("Client certificate was not communicated correctly")
					}
				}
				if tt.serverCfg.ClientAuth == NoClientCert {
					if actualClientCert != nil {
						t.Errorf("Client certificate wasn't expected")
					}
				}

				actualServerCert := res.c.ConnectionState().PeerCertificates
				if actualServerCert == nil {
					t.Errorf("Server did not provide a certificate")
				}

				if len(actualServerCert) != len(tt.serverCfg.Certificates[0].Certificate) || !bytes.Equal(tt.serverCfg.Certificates[0].Certificate[0], actualServerCert[0]) {
					t.Errorf("Server certificate was not communicated correctly")
				}
			})
		}
	})
}
// TestExtendedMasterSecret runs a handshake for every combination of the
// Request / Require / Disable extended-master-secret settings on client and
// server. Only the two combinations where one side requires the extension
// and the other disables it are expected to fail.
func TestExtendedMasterSecret(t *testing.T) {
	// Report goroutines that are still alive once the body has finished.
	checkLeaks := test.CheckRoutines(t)
	defer checkLeaks()

	type emsCase struct {
		clientCfg         *Config
		serverCfg         *Config
		expectedClientErr error
		expectedServerErr error
	}

	// Entries without expected errors must complete the handshake on both sides.
	cases := map[string]emsCase{
		"Request_Request_ExtendedMasterSecret": {
			clientCfg: &Config{ExtendedMasterSecret: RequestExtendedMasterSecret},
			serverCfg: &Config{ExtendedMasterSecret: RequestExtendedMasterSecret},
		},
		"Request_Require_ExtendedMasterSecret": {
			clientCfg: &Config{ExtendedMasterSecret: RequestExtendedMasterSecret},
			serverCfg: &Config{ExtendedMasterSecret: RequireExtendedMasterSecret},
		},
		"Request_Disable_ExtendedMasterSecret": {
			clientCfg: &Config{ExtendedMasterSecret: RequestExtendedMasterSecret},
			serverCfg: &Config{ExtendedMasterSecret: DisableExtendedMasterSecret},
		},
		"Require_Request_ExtendedMasterSecret": {
			clientCfg: &Config{ExtendedMasterSecret: RequireExtendedMasterSecret},
			serverCfg: &Config{ExtendedMasterSecret: RequestExtendedMasterSecret},
		},
		"Require_Require_ExtendedMasterSecret": {
			clientCfg: &Config{ExtendedMasterSecret: RequireExtendedMasterSecret},
			serverCfg: &Config{ExtendedMasterSecret: RequireExtendedMasterSecret},
		},
		"Require_Disable_ExtendedMasterSecret": {
			clientCfg:         &Config{ExtendedMasterSecret: RequireExtendedMasterSecret},
			serverCfg:         &Config{ExtendedMasterSecret: DisableExtendedMasterSecret},
			expectedClientErr: errClientRequiredButNoServerEMS,
			expectedServerErr: &errAlert{&alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}},
		},
		"Disable_Request_ExtendedMasterSecret": {
			clientCfg: &Config{ExtendedMasterSecret: DisableExtendedMasterSecret},
			serverCfg: &Config{ExtendedMasterSecret: RequestExtendedMasterSecret},
		},
		"Disable_Require_ExtendedMasterSecret": {
			clientCfg:         &Config{ExtendedMasterSecret: DisableExtendedMasterSecret},
			serverCfg:         &Config{ExtendedMasterSecret: RequireExtendedMasterSecret},
			expectedClientErr: &errAlert{&alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}},
			expectedServerErr: errServerRequiredButNoClientEMS,
		},
		"Disable_Disable_ExtendedMasterSecret": {
			clientCfg: &Config{ExtendedMasterSecret: DisableExtendedMasterSecret},
			serverCfg: &Config{ExtendedMasterSecret: DisableExtendedMasterSecret},
		},
	}

	for name, tc := range cases {
		tc := tc
		t.Run(name, func(t *testing.T) {
			ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
			defer cancel()

			clientSide, serverSide := dpipe.Pipe()

			type result struct {
				c   *Conn
				err error
			}
			clientCh := make(chan result)

			go func() {
				conn, handshakeErr := testClient(ctx, clientSide, tc.clientCfg, true)
				clientCh <- result{conn, handshakeErr}
			}()

			server, serverErr := testServer(ctx, serverSide, tc.serverCfg, true)
			clientRes := <-clientCh

			// Close whichever side managed to finish its handshake.
			defer func() {
				if serverErr == nil {
					_ = server.Close()
				}
				if clientRes.err == nil {
					_ = clientRes.c.Close()
				}
			}()

			if !errors.Is(clientRes.err, tc.expectedClientErr) {
				t.Errorf("Client error expected: \"%v\" but got \"%v\"", tc.expectedClientErr, clientRes.err)
			}
			if !errors.Is(serverErr, tc.expectedServerErr) {
				t.Errorf("Server error expected: \"%v\" but got \"%v\"", tc.expectedServerErr, serverErr)
			}
		})
	}
}
// TestServerCertificate checks how the client validates the server
// certificate (root CAs, InsecureSkipVerify, ServerName) and that custom
// VerifyPeerCertificate callbacks on either side are honoured.
//
// The parallel subtests are grouped under a single "parallel" subtest so that
// they have all finished before the goroutine-leak report runs.
func TestServerCertificate(t *testing.T) {
	// Check for leaking routines
	report := test.CheckRoutines(t)
	defer report()

	cert, err := selfsign.GenerateSelfSigned()
	if err != nil {
		t.Fatal(err)
	}
	certificate, err := x509.ParseCertificate(cert.Certificate[0])
	if err != nil {
		t.Fatal(err)
	}
	caPool := x509.NewCertPool()
	caPool.AddCert(certificate)

	t.Run("parallel", func(t *testing.T) { // sync routines to check routine leak
		tests := map[string]struct {
			clientCfg *Config
			serverCfg *Config
			wantErr   bool
		}{
			"no_ca": {
				clientCfg: &Config{},
				serverCfg: &Config{Certificates: []tls.Certificate{cert}, ClientAuth: NoClientCert},
				wantErr:   true,
			},
			"good_ca": {
				clientCfg: &Config{RootCAs: caPool},
				serverCfg: &Config{Certificates: []tls.Certificate{cert}, ClientAuth: NoClientCert},
			},
			"no_ca_skip_verify": {
				clientCfg: &Config{InsecureSkipVerify: true},
				serverCfg: &Config{Certificates: []tls.Certificate{cert}, ClientAuth: NoClientCert},
			},
			"good_ca_skip_verify_custom_verify_peer": {
				clientCfg: &Config{RootCAs: caPool, Certificates: []tls.Certificate{cert}},
				serverCfg: &Config{Certificates: []tls.Certificate{cert}, ClientAuth: RequireAnyClientCert, VerifyPeerCertificate: func(_ [][]byte, chain [][]*x509.Certificate) error {
					// This case expects the callback to receive no chain.
					if len(chain) != 0 {
						return errNotExpectedChain
					}
					return nil
				}},
			},
			"good_ca_verify_custom_verify_peer": {
				clientCfg: &Config{RootCAs: caPool, Certificates: []tls.Certificate{cert}},
				serverCfg: &Config{ClientCAs: caPool, Certificates: []tls.Certificate{cert}, ClientAuth: RequireAndVerifyClientCert, VerifyPeerCertificate: func(_ [][]byte, chain [][]*x509.Certificate) error {
					// This case expects the callback to receive a chain.
					if len(chain) == 0 {
						return errExpecedChain
					}
					return nil
				}},
			},
			"good_ca_custom_verify_peer": {
				clientCfg: &Config{
					RootCAs: caPool,
					VerifyPeerCertificate: func([][]byte, [][]*x509.Certificate) error {
						return errWrongCert
					},
				},
				serverCfg: &Config{Certificates: []tls.Certificate{cert}, ClientAuth: NoClientCert},
				wantErr:   true,
			},
			"server_name": {
				clientCfg: &Config{RootCAs: caPool, ServerName: certificate.Subject.CommonName},
				serverCfg: &Config{Certificates: []tls.Certificate{cert}, ClientAuth: NoClientCert},
			},
			"server_name_error": {
				clientCfg: &Config{RootCAs: caPool, ServerName: "barfoo"},
				serverCfg: &Config{Certificates: []tls.Certificate{cert}, ClientAuth: NoClientCert},
				wantErr:   true,
			},
		}
		for name, tt := range tests {
			tt := tt
			t.Run(name, func(t *testing.T) {
				t.Parallel()

				ca, cb := dpipe.Pipe()

				type result struct {
					c   *Conn
					err error
				}
				// Buffered so the server goroutine can always hand over its
				// result and exit.
				srvCh := make(chan result, 1)
				go func() {
					s, err := Server(cb, tt.serverCfg)
					srvCh <- result{s, err}
				}()

				cli, err := Client(ca, tt.clientCfg)
				if err == nil {
					_ = cli.Close()
				}

				// Reap the server before asserting, so a fatal assertion can
				// neither strand the goroutine nor leave the server open.
				srv := <-srvCh
				if srv.err == nil {
					_ = srv.c.Close()
				}

				if !tt.wantErr && err != nil {
					t.Errorf("Client failed(%v)", err)
				}
				if tt.wantErr && err == nil {
					t.Fatal("Error expected")
				}
			})
		}
	})
}
// TestCipherSuiteConfiguration checks cipher suite negotiation for several
// client/server suite lists, verifying the handshake errors on both sides
// and, where one is expected, the suite recorded in the client's state.
func TestCipherSuiteConfiguration(t *testing.T) {
	// Check for leaking routines
	report := test.CheckRoutines(t)
	defer report()

	for _, test := range []struct {
		Name                    string
		ClientCipherSuites      []CipherSuiteID
		ServerCipherSuites      []CipherSuiteID
		WantClientError         error
		WantServerError         error
		WantSelectedCipherSuite CipherSuiteID
	}{
		{
			Name:               "No CipherSuites specified",
			ClientCipherSuites: nil,
			ServerCipherSuites: nil,
			WantClientError:    nil,
			WantServerError:    nil,
		},
		{
			Name:               "Invalid CipherSuite",
			ClientCipherSuites: []CipherSuiteID{0x00},
			ServerCipherSuites: []CipherSuiteID{0x00},
			WantClientError:    &invalidCipherSuite{0x00},
			WantServerError:    &invalidCipherSuite{0x00},
		},
		{
			Name:                    "Valid CipherSuites specified",
			ClientCipherSuites:      []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
			ServerCipherSuites:      []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
			WantClientError:         nil,
			WantServerError:         nil,
			WantSelectedCipherSuite: TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
		},
		{
			Name:               "CipherSuites mismatch",
			ClientCipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
			ServerCipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA},
			WantClientError:    &errAlert{&alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}},
			WantServerError:    errCipherSuiteNoIntersection,
		},
		{
			Name:                    "Valid CipherSuites CCM specified",
			ClientCipherSuites:      []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_CCM},
			ServerCipherSuites:      []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_CCM},
			WantClientError:         nil,
			WantServerError:         nil,
			WantSelectedCipherSuite: TLS_ECDHE_ECDSA_WITH_AES_128_CCM,
		},
		{
			Name:                    "Valid CipherSuites CCM-8 specified",
			ClientCipherSuites:      []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8},
			ServerCipherSuites:      []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8},
			WantClientError:         nil,
			WantServerError:         nil,
			WantSelectedCipherSuite: TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8,
		},
		{
			Name:                    "Server supports subset of client suites",
			ClientCipherSuites:      []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA},
			ServerCipherSuites:      []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA},
			WantClientError:         nil,
			WantServerError:         nil,
			WantSelectedCipherSuite: TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
		},
	} {
		test := test
		t.Run(test.Name, func(t *testing.T) {
			ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
			defer cancel()

			ca, cb := dpipe.Pipe()
			type result struct {
				c   *Conn
				err error
			}
			c := make(chan result)

			go func() {
				client, err := testClient(ctx, ca, &Config{CipherSuites: test.ClientCipherSuites}, true)
				c <- result{client, err}
			}()

			server, err := testServer(ctx, cb, &Config{CipherSuites: test.ServerCipherSuites}, true)
			if err == nil {
				defer func() {
					_ = server.Close()
				}()
			}
			if !errors.Is(err, test.WantServerError) {
				t.Errorf("TestCipherSuiteConfiguration: Server Error Mismatch '%s': expected(%v) actual(%v)", test.Name, test.WantServerError, err)
			}

			res := <-c
			if res.err == nil {
				// Close the client connection itself. The server is handled
				// by its own deferred Close above and may be nil here.
				defer func() {
					_ = res.c.Close()
				}()
			}
			if !errors.Is(res.err, test.WantClientError) {
				t.Errorf("TestCipherSuiteConfiguration: Client Error Mismatch '%s': expected(%v) actual(%v)", test.Name, test.WantClientError, res.err)
			}

			if test.WantSelectedCipherSuite == 0x00 {
				return
			}
			// The selected suite can only be read from an established client.
			if res.c == nil {
				t.Fatalf("TestCipherSuiteConfiguration: no client connection to inspect for '%s'", test.Name)
			}
			if selected := res.c.state.cipherSuite.ID(); selected != test.WantSelectedCipherSuite {
				t.Errorf("TestCipherSuiteConfiguration: Server Selected Bad Cipher Suite '%s': expected(%v) actual(%v)", test.Name, test.WantSelectedCipherSuite, selected)
			}
		})
	}
}
// TestCertificateAndPSKServer checks that a server offering both an ECDSA
// certificate cipher suite and a PSK cipher suite completes a handshake with a
// client that offers only one of the two.
func TestCertificateAndPSKServer(t *testing.T) {
	// Check for leaking routines
	report := test.CheckRoutines(t)
	defer report()

	for _, tc := range []struct {
		Name      string
		ClientPSK bool // when true the client offers only the PSK suite
	}{
		{
			Name:      "Client uses PKI",
			ClientPSK: false,
		},
		{
			Name:      "Client uses PSK",
			ClientPSK: true,
		},
	} {
		tc := tc
		t.Run(tc.Name, func(t *testing.T) {
			ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
			defer cancel()

			ca, cb := dpipe.Pipe()
			type result struct {
				c   *Conn
				err error
			}
			c := make(chan result)

			// The client handshake runs concurrently with the server handshake below.
			go func() {
				config := &Config{CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}}
				if tc.ClientPSK {
					config.PSK = func([]byte) ([]byte, error) {
						return []byte{0x00, 0x01, 0x02}, nil
					}
					config.PSKIdentityHint = []byte{0x00}
					config.CipherSuites = []CipherSuiteID{TLS_PSK_WITH_AES_128_GCM_SHA256}
				}

				client, err := testClient(ctx, ca, config, false)
				c <- result{client, err}
			}()

			config := &Config{
				CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, TLS_PSK_WITH_AES_128_GCM_SHA256},
				PSK: func([]byte) ([]byte, error) {
					return []byte{0x00, 0x01, 0x02}, nil
				},
			}

			server, err := testServer(ctx, cb, config, true)
			if err == nil {
				defer func() {
					_ = server.Close()
				}()
			} else {
				t.Errorf("TestCertificateAndPSKServer: Server Error Mismatch '%s': expected(%v) actual(%v)", tc.Name, nil, err)
			}

			res := <-c
			if res.err == nil {
				// Close the client connection itself. The server is released by
				// its own deferred Close and may be nil here if its handshake failed.
				_ = res.c.Close()
			} else {
				t.Errorf("TestCertificateAndPSKServer: Client Error Mismatch '%s': expected(%v) actual(%v)", tc.Name, nil, res.err)
			}
		})
	}
}
// TestPSKConfiguration checks the errors returned by testClient and testServer
// for various combinations of PSK callback, PSK identity hint and certificate
// presence. Each table entry runs as its own subtest.
func TestPSKConfiguration(t *testing.T) {
	// Check for leaking routines
	report := test.CheckRoutines(t)
	defer report()

	// sameError reports whether got matches want. Errors are compared by
	// message text, and two nil errors are considered equal.
	sameError := func(got, want error) bool {
		if got == nil || want == nil {
			return got == nil && want == nil
		}
		return got.Error() == want.Error()
	}

	for _, tc := range []struct {
		Name                 string
		ClientHasCertificate bool
		ServerHasCertificate bool
		ClientPSK            PSKCallback
		ServerPSK            PSKCallback
		ClientPSKIdentity    []byte
		ServerPSKIdentity    []byte
		WantClientError      error
		WantServerError      error
	}{
		{
			Name:                 "PSK and no certificate specified",
			ClientHasCertificate: false,
			ServerHasCertificate: false,
			ClientPSK:            func([]byte) ([]byte, error) { return []byte{0x00, 0x01, 0x02}, nil },
			ServerPSK:            func([]byte) ([]byte, error) { return []byte{0x00, 0x01, 0x02}, nil },
			ClientPSKIdentity:    []byte{0x00},
			ServerPSKIdentity:    []byte{0x00},
			WantClientError:      errNoAvailablePSKCipherSuite,
			WantServerError:      errNoAvailablePSKCipherSuite,
		},
		{
			Name:                 "PSK and certificate specified",
			ClientHasCertificate: true,
			ServerHasCertificate: true,
			ClientPSK:            func([]byte) ([]byte, error) { return []byte{0x00, 0x01, 0x02}, nil },
			ServerPSK:            func([]byte) ([]byte, error) { return []byte{0x00, 0x01, 0x02}, nil },
			ClientPSKIdentity:    []byte{0x00},
			ServerPSKIdentity:    []byte{0x00},
			WantClientError:      errNoAvailablePSKCipherSuite,
			WantServerError:      errNoAvailablePSKCipherSuite,
		},
		{
			Name:                 "PSK and no identity specified",
			ClientHasCertificate: false,
			ServerHasCertificate: false,
			ClientPSK:            func([]byte) ([]byte, error) { return []byte{0x00, 0x01, 0x02}, nil },
			ServerPSK:            func([]byte) ([]byte, error) { return []byte{0x00, 0x01, 0x02}, nil },
			ClientPSKIdentity:    nil,
			ServerPSKIdentity:    nil,
			WantClientError:      errPSKAndIdentityMustBeSetForClient,
			WantServerError:      errNoAvailablePSKCipherSuite,
		},
		{
			Name:                 "No PSK and identity specified",
			ClientHasCertificate: false,
			ServerHasCertificate: false,
			ClientPSK:            nil,
			ServerPSK:            nil,
			ClientPSKIdentity:    []byte{0x00},
			ServerPSKIdentity:    []byte{0x00},
			WantClientError:      errIdentityNoPSK,
			WantServerError:      errIdentityNoPSK,
		},
	} {
		tc := tc
		t.Run(tc.Name, func(t *testing.T) {
			// Scoped to the subtest so the context is released after each
			// case rather than when the whole test function returns.
			ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
			defer cancel()

			ca, cb := dpipe.Pipe()
			type result struct {
				c   *Conn
				err error
			}
			// Buffered so the client goroutine can always deliver its result
			// and exit, even if the subtest aborts before receiving it.
			c := make(chan result, 1)

			go func() {
				client, err := testClient(ctx, ca, &Config{PSK: tc.ClientPSK, PSKIdentityHint: tc.ClientPSKIdentity}, tc.ClientHasCertificate)
				c <- result{client, err}
			}()

			server, err := testServer(ctx, cb, &Config{PSK: tc.ServerPSK, PSKIdentityHint: tc.ServerPSKIdentity}, tc.ServerHasCertificate)
			if err == nil {
				// Release the connection if the handshake succeeded.
				defer func() {
					_ = server.Close()
				}()
			}
			if !sameError(err, tc.WantServerError) {
				t.Fatalf("TestPSKConfiguration: Server Error Mismatch '%s': expected(%v) actual(%v)", tc.Name, tc.WantServerError, err)
			}

			res := <-c
			if res.err == nil {
				// Release the connection if the handshake succeeded.
				defer func() {
					_ = res.c.Close()
				}()
			}
			if !sameError(res.err, tc.WantClientError) {
				t.Fatalf("TestPSKConfiguration: Client Error Mismatch '%s': expected(%v) actual(%v)", tc.Name, tc.WantClientError, res.err)
			}
		})
	}
}
// TestServerTimeout checks that a server handshake bounded by a short context
// deadline fails with a network timeout error, and that no further packets
// arrive from the server after that failure.
func TestServerTimeout(t *testing.T) {
	// Limit runtime in case of deadlocks
	lim := test.TimeOut(time.Second * 20)
	defer lim.Stop()

	// Check for leaking routines
	report := test.CheckRoutines(t)
	defer report()

	cookie := make([]byte, 20)
	if _, err := rand.Read(cookie); err != nil {
		t.Fatal(err)
	}

	// Zero-valued random bytes; named so it does not shadow the crypto/rand package.
	var randomBytes [28]byte
	random := handshake.Random{GMTUnixTime: time.Unix(500, 0), RandomBytes: randomBytes}

	cipherSuites := []CipherSuite{
		&ciphersuite.TLSEcdheEcdsaWithAes128GcmSha256{},
		&ciphersuite.TLSEcdheRsaWithAes128GcmSha256{},
	}

	extensions := []extension.Extension{
		&extension.SupportedSignatureAlgorithms{
			SignatureHashAlgorithms: []signaturehash.Algorithm{
				{Hash: hash.SHA256, Signature: signature.ECDSA},
				{Hash: hash.SHA384, Signature: signature.ECDSA},
				{Hash: hash.SHA512, Signature: signature.ECDSA},
				{Hash: hash.SHA256, Signature: signature.RSA},
				{Hash: hash.SHA384, Signature: signature.RSA},
				{Hash: hash.SHA512, Signature: signature.RSA},
			},
		},
		&extension.SupportedEllipticCurves{
			EllipticCurves: []elliptic.Curve{elliptic.X25519, elliptic.P256, elliptic.P384},
		},
		&extension.SupportedPointFormats{
			PointFormats: []elliptic.CurvePointFormat{elliptic.CurvePointFormatUncompressed},
		},
	}

	// Hand-built ClientHello that is sent repeatedly to the server.
	record := &recordlayer.RecordLayer{
		Header: recordlayer.Header{
			SequenceNumber: 0,
			Version:        protocol.Version1_2,
		},
		Content: &handshake.Handshake{
			// sequenceNumber and messageSequence line up, may need to be re-evaluated
			Header: handshake.Header{
				MessageSequence: 0,
			},
			Message: &handshake.MessageClientHello{
				Version:            protocol.Version1_2,
				Cookie:             cookie,
				Random:             random,
				CipherSuiteIDs:     cipherSuiteIDs(cipherSuites),
				CompressionMethods: defaultCompressionMethods(),
				Extensions:         extensions,
			},
		},
	}

	packet, err := record.Marshal()
	if err != nil {
		t.Fatal(err)
	}

	ca, cb := dpipe.Pipe()
	defer func() {
		// Report with Error rather than Fatal: this runs during teardown,
		// matching the other tests in this file.
		if err := ca.Close(); err != nil {
			t.Error(err)
		}
	}()

	// Client reader: forwards every packet received on ca until Read fails.
	caReadChan := make(chan []byte, 1000)
	go func() {
		for {
			data := make([]byte, 8192)
			n, err := ca.Read(data)
			if err != nil {
				return
			}

			caReadChan <- data[:n]
		}
	}()

	// Start sending ClientHello packets until server responds with first packet
	go func() {
		for {
			select {
			case <-time.After(10 * time.Millisecond):
				if _, err := ca.Write(packet); err != nil {
					return
				}
			case <-caReadChan:
				// Once we receive the first reply from the server, stop
				return
			}
		}
	}()

	ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
	defer cancel()

	config := &Config{
		CipherSuites:   []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
		FlightInterval: 100 * time.Millisecond,
	}

	_, serverErr := testServer(ctx, cb, config, true)
	// errors.As also matches a net.Error that is wrapped inside another error.
	var netErr net.Error
	if !errors.As(serverErr, &netErr) || !netErr.Timeout() {
		t.Fatalf("Server error exp(network timeout error) failed(%v)", serverErr)
	}

	// Wait a little longer to ensure no additional messages have been sent by the server
	time.Sleep(300 * time.Millisecond)
	select {
	case msg := <-caReadChan:
		t.Fatalf("Expected no additional messages from server, got: %+v", msg)
	default:
	}
}
// TestProtocolVersionValidation feeds hand-built handshake records to a real
// server (testServer) and a real client (testClient) and checks that a hello
// message carrying a version other than protocol.Version1_2 makes the peer
// fail with errUnsupportedProtocolVersion and reply with an alert record.
func TestProtocolVersionValidation(t *testing.T) {
// Limit runtime in case of deadlocks
lim := test.TimeOut(time.Second * 20)
defer lim.Stop()
// Check for leaking routines
report := test.CheckRoutines(t)
defer report()
cookie := make([]byte, 20)
if _, err := rand.Read(cookie); err != nil {
t.Fatal(err)
}
// NOTE(review): from here on the local array `rand` shadows the crypto/rand
// package; it is left zero-valued.
var rand [28]byte
random := handshake.Random{GMTUnixTime: time.Unix(500, 0), RandomBytes: rand}
// Supplies the public key placed in the fake ServerKeyExchange below.
localKeypair, err := elliptic.GenerateKeypair(elliptic.X25519)
if err != nil {
t.Fatal(err)
}
// Shared by the server and the client under test.
config := &Config{
CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
FlightInterval: 100 * time.Millisecond,
}
// Server side: the test plays the client and sends crafted ClientHello records.
t.Run("Server", func(t *testing.T) {
serverCases := map[string]struct {
records []*recordlayer.RecordLayer
}{
// A single ClientHello whose body already carries the unsupported version.
"ClientHelloVersion": {
records: []*recordlayer.RecordLayer{
{
Header: recordlayer.Header{
Version: protocol.Version1_2,
},
Content: &handshake.Handshake{
Message: &handshake.MessageClientHello{
Version: protocol.Version{Major: 0xfe, Minor: 0xff}, // try to downgrade
Cookie: cookie,
Random: random,
CipherSuiteIDs: []uint16{uint16((&ciphersuite.TLSEcdheEcdsaWithAes128GcmSha256{}).ID())},
CompressionMethods: defaultCompressionMethods(),
},
},
},
},
},
// A Version1_2 ClientHello followed by a second ClientHello (sequence 1)
// that carries the unsupported version.
"SecondsClientHelloVersion": {
records: []*recordlayer.RecordLayer{
{
Header: recordlayer.Header{
Version: protocol.Version1_2,
},
Content: &handshake.Handshake{
Message: &handshake.MessageClientHello{
Version: protocol.Version1_2,
Cookie: cookie,
Random: random,
CipherSuiteIDs: []uint16{uint16((&ciphersuite.TLSEcdheEcdsaWithAes128GcmSha256{}).ID())},
CompressionMethods: defaultCompressionMethods(),
},
},
},
{
Header: recordlayer.Header{
Version: protocol.Version1_2,
SequenceNumber: 1,
},
Content: &handshake.Handshake{
Header: handshake.Header{
MessageSequence: 1,
},
Message: &handshake.MessageClientHello{
Version: protocol.Version{Major: 0xfe, Minor: 0xff}, // try to downgrade
Cookie: cookie,
Random: random,
CipherSuiteIDs: []uint16{uint16((&ciphersuite.TLSEcdheEcdsaWithAes128GcmSha256{}).ID())},
CompressionMethods: defaultCompressionMethods(),
},
},
},
},
},
}
for name, c := range serverCases {
c := c
t.Run(name, func(t *testing.T) {
ca, cb := dpipe.Pipe()
defer func() {
err := ca.Close()
if err != nil {
t.Error(err)
}
}()
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
// Deferred last, so the subtest waits for the server goroutine before
// cancel and ca.Close run.
var wg sync.WaitGroup
wg.Add(1)
defer wg.Wait()
go func() {
defer wg.Done()
// NOTE(review): the message says "Client error" but the value checked is
// the error returned by testServer.
if _, err := testServer(ctx, cb, config, true); !errors.Is(err, errUnsupportedProtocolVersion) {
t.Errorf("Client error exp(%v) failed(%v)", errUnsupportedProtocolVersion, err)
}
}()
// Presumably gives the server goroutine time to start before the first write.
time.Sleep(50 * time.Millisecond)
resp := make([]byte, 1024)
// Send each record and read one reply per record; resp is re-sliced to its
// full capacity for every read and then trimmed to the bytes received, so
// after the loop it holds the reply to the last record.
for _, record := range c.records {
packet, err := record.Marshal()
if err != nil {
t.Fatal(err)
}
if _, werr := ca.Write(packet); werr != nil {
t.Fatal(werr)
}
n, rerr := ca.Read(resp[:cap(resp)])
if rerr != nil {
t.Fatal(rerr)
}
resp = resp[:n]
}
// Only the record header of the final reply is inspected.
h := &recordlayer.Header{}
if err := h.Unmarshal(resp); err != nil {
t.Fatal("Failed to unmarshal response")
}
if h.ContentType != protocol.ContentTypeAlert {
t.Errorf("Peer must return alert to unsupported protocol version")
}
})
}
})
// Client side: the test plays the server and answers the client's packets
// with a crafted server flight.
t.Run("Client", func(t *testing.T) {
clientCases := map[string]struct {
records []*recordlayer.RecordLayer
}{
// HelloVerifyRequest, then a ServerHello with the unsupported version,
// followed by Certificate, ServerKeyExchange and ServerHelloDone.
"ServerHelloVersion": {
records: []*recordlayer.RecordLayer{
{
Header: recordlayer.Header{
Version: protocol.Version1_2,
},
Content: &handshake.Handshake{
Message: &handshake.MessageHelloVerifyRequest{
Version: protocol.Version1_2,
Cookie: cookie,
},
},
},
{
Header: recordlayer.Header{
Version: protocol.Version1_2,
SequenceNumber: 1,
},
Content: &handshake.Handshake{
Header: handshake.Header{
MessageSequence: 1,
},
Message: &handshake.MessageServerHello{
Version: protocol.Version{Major: 0xfe, Minor: 0xff}, // try to downgrade
Random: random,
CipherSuiteID: func() *uint16 { id := uint16(TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256); return &id }(),
CompressionMethod: defaultCompressionMethods()[0],
},
},
},
{
Header: recordlayer.Header{
Version: protocol.Version1_2,
SequenceNumber: 2,
},
Content: &handshake.Handshake{
Header: handshake.Header{
MessageSequence: 2,
},
Message: &handshake.MessageCertificate{},
},
},
{
Header: recordlayer.Header{
Version: protocol.Version1_2,
SequenceNumber: 3,
},
Content: &handshake.Handshake{
Header: handshake.Header{
MessageSequence: 3,
},
Message: &handshake.MessageServerKeyExchange{
EllipticCurveType: elliptic.CurveTypeNamedCurve,
NamedCurve: elliptic.X25519,
PublicKey: localKeypair.PublicKey,
HashAlgorithm: hash.SHA256,
SignatureAlgorithm: signature.ECDSA,
// All-zero placeholder signature.
Signature: make([]byte, 64),
},
},
},
{
Header: recordlayer.Header{
Version: protocol.Version1_2,
SequenceNumber: 4,
},
Content: &handshake.Handshake{
Header: handshake.Header{
MessageSequence: 4,
},
Message: &handshake.MessageServerHelloDone{},
},
},
},
},
}
for name, c := range clientCases {
c := c
t.Run(name, func(t *testing.T) {
ca, cb := dpipe.Pipe()
defer func() {
err := ca.Close()
if err != nil {
t.Error(err)
}
}()
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
// Deferred last, so the subtest waits for the client goroutine before
// cancel and ca.Close run.
var wg sync.WaitGroup
wg.Add(1)
defer wg.Wait()
go func() {
defer wg.Done()
// NOTE(review): the message says "Server error" but the value checked is
// the error returned by testClient.
if _, err := testClient(ctx, cb, config, true); !errors.Is(err, errUnsupportedProtocolVersion) {
t.Errorf("Server error exp(%v) failed(%v)", errUnsupportedProtocolVersion, err)
}
}()
// Presumably gives the client goroutine time to start before the first read.
time.Sleep(50 * time.Millisecond)
// Before each crafted record is written, one packet from the client is
// read and discarded.
for _, record := range c.records {
if _, err := ca.Read(make([]byte, 1024)); err != nil {
t.Fatal(err)
}
packet, err := record.Marshal()
if err != nil {
t.Fatal(err)
}
if _, err := ca.Write(packet); err != nil {
t.Fatal(err)
}
}
// The packet read after the whole flight was sent must be an alert.
resp := make([]byte, 1024)
n, err := ca.Read(resp)
if err != nil {
t.Fatal(err)
}
resp = resp[:n]
h := &recordlayer.Header{}
if err := h.Unmarshal(resp); err != nil {
t.Fatal("Failed to unmarshal response")
}
if h.ContentType != protocol.ContentTypeAlert {
t.Errorf("Peer must return alert to unsupported protocol version")
}
})
}
})
}
// TestMultipleHelloVerifyRequest checks that a client answers each of two
// consecutive HelloVerifyRequest messages with a ClientHello that carries the
// cookie from the most recent request. The first ClientHello is expected to
// carry an empty cookie.
func TestMultipleHelloVerifyRequest(t *testing.T) {
	// Limit runtime in case of deadlocks
	lim := test.TimeOut(time.Second * 20)
	defer lim.Stop()

	// Check for leaking routines
	report := test.CheckRoutines(t)
	defer report()

	cookies := [][]byte{
		// first clientHello contains an empty cookie
		{},
	}
	var packets [][]byte
	for i := 0; i < 2; i++ {
		cookie := make([]byte, 20)
		if _, err := rand.Read(cookie); err != nil {
			t.Fatal(err)
		}
		cookies = append(cookies, cookie)

		record := &recordlayer.RecordLayer{
			Header: recordlayer.Header{
				SequenceNumber: uint64(i),
				Version:        protocol.Version1_2,
			},
			Content: &handshake.Handshake{
				Header: handshake.Header{
					MessageSequence: uint16(i),
				},
				Message: &handshake.MessageHelloVerifyRequest{
					Version: protocol.Version1_2,
					Cookie:  cookie,
				},
			},
		}
		packet, err := record.Marshal()
		if err != nil {
			t.Fatal(err)
		}
		packets = append(packets, packet)
	}

	ca, cb := dpipe.Pipe()
	defer func() {
		if err := ca.Close(); err != nil {
			t.Error(err)
		}
	}()

	ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)

	var wg sync.WaitGroup
	wg.Add(1)
	// Cancel before waiting so that a failing assertion does not have to sit
	// out the context timeout before the client goroutine returns.
	defer func() {
		cancel()
		wg.Wait()
	}()

	go func() {
		defer wg.Done()
		// The handshake is never completed; its result is irrelevant here.
		_, _ = testClient(ctx, ca, &Config{}, false)
	}()

	for i, cookie := range cookies {
		// read client hello
		resp := make([]byte, 1024)
		n, err := cb.Read(resp)
		if err != nil {
			t.Fatal(err)
		}
		record := &recordlayer.RecordLayer{}
		if err := record.Unmarshal(resp[:n]); err != nil {
			t.Fatal(err)
		}

		// Checked assertions turn an unexpected message into a readable
		// failure instead of a panic.
		hs, ok := record.Content.(*handshake.Handshake)
		if !ok {
			t.Fatalf("Expected handshake record, got: %T", record.Content)
		}
		clientHello, ok := hs.Message.(*handshake.MessageClientHello)
		if !ok {
			t.Fatalf("Expected ClientHello, got: %T", hs.Message)
		}
		if !bytes.Equal(clientHello.Cookie, cookie) {
			t.Fatalf("Wrong cookie, expected: %x, got: %x", cookie, clientHello.Cookie)
		}

		// There is one more expected cookie than there are requests to send.
		if len(packets) <= i {
			break
		}
		// write hello verify request
		if _, err := cb.Write(packets[i]); err != nil {
			t.Fatal(err)
		}
	}
}
// TestRenegotationInfo asserts that a DTLS server always includes the
// RenegotiationInfo extension in its ServerHello, whether or not the
// ClientHello contained that extension.
func TestRenegotationInfo(t *testing.T) {
	// Limit runtime in case of deadlocks
	lim := test.TimeOut(10 * time.Second)
	defer lim.Stop()

	// Check for leaking routines
	report := test.CheckRoutines(t)
	defer report()

	// Receive buffer reused by every subtest; subtests run sequentially.
	resp := make([]byte, 1024)

	for _, testCase := range []struct {
		Name                  string
		SendRenegotiationInfo bool
	}{
		{
			"Include RenegotiationInfo",
			true,
		},
		{
			"No RenegotiationInfo",
			false,
		},
	} {
		tc := testCase
		t.Run(tc.Name, func(t *testing.T) {
			// sendClientHello writes a ClientHello with the given cookie to ca,
			// using sequenceNumber for both the record and the handshake message.
			sendClientHello := func(cookie []byte, ca net.Conn, sequenceNumber uint64) {
				extensions := []extension.Extension{}
				if tc.SendRenegotiationInfo {
					extensions = append(extensions, &extension.RenegotiationInfo{
						RenegotiatedConnection: 0,
					})
				}

				packet, err := (&recordlayer.RecordLayer{
					Header: recordlayer.Header{
						Version:        protocol.Version1_2,
						SequenceNumber: sequenceNumber,
					},
					Content: &handshake.Handshake{
						Header: handshake.Header{
							MessageSequence: uint16(sequenceNumber),
						},
						Message: &handshake.MessageClientHello{
							Version:            protocol.Version1_2,
							Cookie:             cookie,
							CipherSuiteIDs:     cipherSuiteIDs(defaultCipherSuites()),
							CompressionMethods: defaultCompressionMethods(),
							Extensions:         extensions,
						},
					},
				}).Marshal()
				if err != nil {
					t.Fatal(err)
				}

				if _, err = ca.Write(packet); err != nil {
					t.Fatal(err)
				}
			}

			ca, cb := dpipe.Pipe()
			defer func() {
				if err := ca.Close(); err != nil {
					t.Error(err)
				}
			}()

			ctx, cancel := context.WithCancel(context.Background())

			var wg sync.WaitGroup
			wg.Add(1)
			// Stop the server and wait for its goroutine, so that it cannot
			// call t.Error after this subtest has already returned.
			defer func() {
				cancel()
				wg.Wait()
			}()

			go func() {
				defer wg.Done()
				// The handshake is never completed; the server is expected to
				// stop only because the context is cancelled.
				if _, err := testServer(ctx, cb, &Config{}, true); !errors.Is(err, context.Canceled) {
					t.Error(err)
				}
			}()

			// Presumably gives the server goroutine time to start.
			time.Sleep(50 * time.Millisecond)

			// First ClientHello carries an empty cookie; the reply is
			// expected to be a HelloVerifyRequest.
			sendClientHello([]byte{}, ca, 0)
			n, err := ca.Read(resp)
			if err != nil {
				t.Fatal(err)
			}
			r := &recordlayer.RecordLayer{}
			if err = r.Unmarshal(resp[:n]); err != nil {
				t.Fatal(err)
			}

			hs, ok := r.Content.(*handshake.Handshake)
			if !ok {
				t.Fatalf("Expected handshake record, got: %T", r.Content)
			}
			helloVerifyRequest, ok := hs.Message.(*handshake.MessageHelloVerifyRequest)
			if !ok {
				t.Fatalf("Expected HelloVerifyRequest, got: %T", hs.Message)
			}

			// Second ClientHello echoes the cookie from the HelloVerifyRequest.
			sendClientHello(helloVerifyRequest.Cookie, ca, 1)
			if n, err = ca.Read(resp); err != nil {
				t.Fatal(err)
			}

			messages, err := recordlayer.UnpackDatagram(resp[:n])
			if err != nil {
				t.Fatal(err)
			}
			if len(messages) == 0 {
				t.Fatal("Expected at least one record in server response")
			}

			// Only the first record of the datagram is inspected.
			if err := r.Unmarshal(messages[0]); err != nil {
				t.Fatal(err)
			}

			hs, ok = r.Content.(*handshake.Handshake)
			if !ok {
				t.Fatalf("Expected handshake record, got: %T", r.Content)
			}
			serverHello, ok := hs.Message.(*handshake.MessageServerHello)
			if !ok {
				t.Fatalf("Expected ServerHello, got: %T", hs.Message)
			}

			gotRenegotiationInfo := false
			for _, v := range serverHello.Extensions {
				if _, ok := v.(*extension.RenegotiationInfo); ok {
					gotRenegotiationInfo = true
					break
				}
			}

			if !gotRenegotiationInfo {
				t.Fatalf("Received ServerHello without RenegotiationInfo")
			}
		})
	}
}