package dtls

import (
	"bytes"
	"crypto/tls"
	"errors"
	"fmt"
	"net"
	"sync"
	"testing"
	"time"

	"github.com/pion/dtls/v2/pkg/crypto/selfsign"
	"github.com/pion/transport/test"
)

// errMessageMissmatch is wrapped into the test error when an echoed message
// does not match the one that was sent.
var errMessageMissmatch = errors.New("messages missmatch")

// TestResumeClient runs the resume scenario with Client as the local
// (resumed) side and Server as the remote side.
func TestResumeClient(t *testing.T) {
	DoTestResume(t, Client, Server)
}

// TestResumeServer runs the resume scenario with Server as the local
// (resumed) side and Client as the remote side.
func TestResumeServer(t *testing.T) {
	DoTestResume(t, Server, Client)
}

// fatal closes errChan and then fails the test immediately with err.
//
// Closing errChan makes the deferred receive in DoTestResume return at once
// instead of waiting for the remote goroutine.
// NOTE(review): if the remote goroutine sends on errChan after this close,
// that send panics; consider a separate done channel if this becomes flaky.
func fatal(t *testing.T, errChan chan error, err error) {
	t.Helper()
	close(errChan)
	t.Fatal(err)
}

// DoTestResume exercises connection resumption between newLocal and newRemote.
//
// It performs a handshake over a first pipe and checks that one message is
// echoed back by the remote side. It then closes that pipe and round-trips
// the local ConnectionState through MarshalBinary/UnmarshalBinary. Next it
// resumes the local side with Resume over a second pipe and checks that a
// second message is echoed as well.
//
// The remote side uses a backupConn holding both pipes. When an operation on
// the first pipe fails, it switches to the second, so the remote keeps its
// session across the local teardown.
func DoTestResume(t *testing.T, newLocal, newRemote func(net.Conn, *Config) (*Conn, error)) {
	// Limit runtime in case of deadlocks
	lim := test.TimeOut(time.Second * 20)
	defer lim.Stop()

	// Check for leaking routines
	report := test.CheckRoutines(t)
	defer report()

	certificate, err := selfsign.GenerateSelfSigned()
	if err != nil {
		t.Fatal(err)
	}

	// Generate connections: localConn1/rc1 for the initial session,
	// localConn2/rc2 for the resumed one.
	localConn1, rc1 := net.Pipe()
	localConn2, rc2 := net.Pipe()
	remoteConn := &backupConn{curr: rc1, next: rc2}

	// Launch remote in another goroutine; it reports exactly one result.
	errChan := make(chan error, 1)
	defer func() {
		err = <-errChan
		if err != nil {
			t.Fatal(err)
		}
	}()

	config := &Config{
		Certificates:         []tls.Certificate{certificate},
		InsecureSkipVerify:   true,
		ExtendedMasterSecret: RequireExtendedMasterSecret,
	}

	go func() {
		// Every failure path sends once and returns. errChan only has
		// capacity 1, and carrying on would use a nil remote or block on a
		// second send.
		remote, errR := newRemote(remoteConn, config)
		if errR != nil {
			errChan <- errR
			return
		}

		// Echo two messages: one from the original local connection and one
		// from the resumed connection.
		for i := 0; i < 2; i++ {
			recv := make([]byte, 1024)
			var n int
			if n, errR = remote.Read(recv); errR != nil {
				errChan <- errR
				return
			}
			if _, errR = remote.Write(recv[:n]); errR != nil {
				errChan <- errR
				return
			}
		}

		errChan <- nil
	}()

	var local *Conn
	local, err = newLocal(localConn1, config)
	if err != nil {
		fatal(t, errChan, err)
	}
	defer func() {
		_ = local.Close()
	}()

	// Test write and read
	message := []byte("Hello")
	if _, err = local.Write(message); err != nil {
		fatal(t, errChan, err)
	}

	recv := make([]byte, 1024)
	var n int
	n, err = local.Read(recv)
	if err != nil {
		fatal(t, errChan, err)
	}
	if !bytes.Equal(message, recv[:n]) {
		fatal(t, errChan, fmt.Errorf("%w: %s != %s", errMessageMissmatch, message, recv[:n]))
	}

	// Tear down the first transport so the remote's backupConn moves to rc2.
	if err = localConn1.Close(); err != nil {
		fatal(t, errChan, err)
	}

	// Serialize and deserialize state
	state := local.ConnectionState()
	var b []byte
	b, err = state.MarshalBinary()
	if err != nil {
		fatal(t, errChan, err)
	}
	deserialized := &State{}
	if err = deserialized.UnmarshalBinary(b); err != nil {
		fatal(t, errChan, err)
	}

	// Resume dtls connection over the second pipe
	var resumed net.Conn
	resumed, err = Resume(deserialized, localConn2, config)
	if err != nil {
		fatal(t, errChan, err)
	}
	defer func() {
		_ = resumed.Close()
	}()

	// Test write and read on resumed connection
	if _, err = resumed.Write(message); err != nil {
		fatal(t, errChan, err)
	}

	recv = make([]byte, 1024)
	n, err = resumed.Read(recv)
	if err != nil {
		fatal(t, errChan, err)
	}
	if !bytes.Equal(message, recv[:n]) {
		fatal(t, errChan, fmt.Errorf("%w: %s != %s", errMessageMissmatch, message, recv[:n]))
	}
}

// Compile-time check that backupConn satisfies net.Conn.
var _ net.Conn = (*backupConn)(nil)

// backupConn forwards Read and Write to curr. The first time an operation
// on curr fails while next is set, it switches to next and retries there.
// Close, the address getters and the deadline setters are no-ops.
type backupConn struct {
	curr net.Conn
	next net.Conn
	mux  sync.Mutex
}

// active returns the connection currently in use.
func (b *backupConn) active() net.Conn {
	b.mux.Lock()
	defer b.mux.Unlock()
	return b.curr
}

// failover is called after an operation on failed returned an error and
// reports whether the operation should be retried.
//
// If failed is still the current connection and a backup exists, it
// switches to the backup. If another goroutine already switched, the caller
// retries on the new connection. Otherwise it returns false. The check and
// the swap happen under the same lock, so concurrent failures cannot swap twice.
func (b *backupConn) failover(failed net.Conn) bool {
	b.mux.Lock()
	defer b.mux.Unlock()
	if b.curr != failed {
		return true
	}
	if b.next == nil {
		return false
	}
	b.curr, b.next = b.next, nil
	return true
}

// Read reads from the active connection, failing over to the backup on error.
func (b *backupConn) Read(data []byte) (n int, err error) {
	for {
		conn := b.active()
		if n, err = conn.Read(data); err == nil || !b.failover(conn) {
			return n, err
		}
	}
}

// Write writes to the active connection, failing over to the backup on error.
func (b *backupConn) Write(data []byte) (n int, err error) {
	for {
		conn := b.active()
		if n, err = conn.Write(data); err == nil || !b.failover(conn) {
			return n, err
		}
	}
}

// Close is a no-op; the underlying pipes are closed through the local side.
func (b *backupConn) Close() error {
	return nil
}

// LocalAddr always returns nil.
func (b *backupConn) LocalAddr() net.Addr {
	return nil
}

// RemoteAddr always returns nil.
func (b *backupConn) RemoteAddr() net.Addr {
	return nil
}

// SetDeadline is a no-op.
func (b *backupConn) SetDeadline(_ time.Time) error {
	return nil
}

// SetReadDeadline is a no-op.
func (b *backupConn) SetReadDeadline(_ time.Time) error {
	return nil
}

// SetWriteDeadline is a no-op.
func (b *backupConn) SetWriteDeadline(_ time.Time) error {
	return nil
}