package buffer

import (
	"bytes"
	"encoding"
	"fmt"
	"io"
	"reflect"
	"testing"
	"unsafe"

	"github.com/google/go-cmp/cmp"
	"github.com/stretchr/testify/require"
	"github.com/tuneinsight/lattigo/v6/utils"
)

// binarySerializer is the set of byte encoding and decoding methods exercised
// by RequireSerializerCorrect.
type binarySerializer interface {
	BinarySize() int
	io.WriterTo
	io.ReaderFrom
	encoding.BinaryMarshaler
	encoding.BinaryUnmarshaler
}

// RequireSerializerCorrect checks that input round-trips through all of its
// serialization methods, and fails the test if:
//   - any of WriteTo, MarshalBinary, ReadFrom or UnmarshalBinary returns an error;
//   - the byte count returned by input.WriteTo differs from input.BinarySize();
//   - the byte count returned by input.WriteTo differs from the number of bytes
//     actually written to the buffer;
//   - input.MarshalBinary produces a different number of bytes, or different
//     bytes, than input.WriteTo;
//   - output.ReadFrom reads a different number of bytes than input.WriteTo wrote;
//   - the object rebuilt with ReadFrom, or with UnmarshalBinary, is not equal to
//     input according to cmp.Equal (a cmp.Diff report is printed on failure).
//
// input is expected to be a pointer (e.g. *T); a fresh zero value of the
// pointed-to type is allocated via reflect for each decoding check.
func RequireSerializerCorrect(t *testing.T, input binarySerializer) {
	// Report failures at the caller's line rather than inside this helper.
	t.Helper()

	// bytes.Buffer is both the io.Writer for WriteTo and the io.Reader for ReadFrom.
	buf := new(bytes.Buffer)

	// io.WriterTo: the reported size must match BinarySize and the buffer length.
	bytesWritten, err := input.WriteTo(buf)
	require.NoError(t, err)
	require.Equalf(t, input.BinarySize(), int(bytesWritten),
		"invalid size: %T.WriteTo #bytes written = %d != %T.BinarySize = %d",
		input, bytesWritten, input, input.BinarySize())
	require.Equalf(t, int(bytesWritten), buf.Len(),
		"invalid size: %T.WriteTo len(buf.Bytes()) = %d != %T.WriteTo #bytes written = %d",
		input, buf.Len(), input, bytesWritten)

	// encoding.BinaryMarshaler: must generate exactly the bytes WriteTo wrote.
	data, err := input.MarshalBinary()
	require.NoError(t, err)
	require.Equalf(t, int(bytesWritten), len(data),
		"invalid size: %T.MarshalBinary #bytes generated = %d != %T.WriteTo #bytes written = %d",
		input, len(data), input, bytesWritten)
	// bytes.Equal treats nil and empty slices as equal, which a zero-size encoding may need.
	require.Truef(t, bytes.Equal(buf.Bytes(), data),
		"invalid encoding: %T.WriteTo buf.Bytes() != %T.MarshalBinary bytes generated",
		input, input)

	// io.ReaderFrom: must consume exactly what WriteTo wrote and rebuild input.
	output := newEmptySerializer(input)
	bytesRead, err := output.ReadFrom(buf)
	require.NoError(t, err)
	require.Equalf(t, bytesWritten, bytesRead,
		"invalid encoding: %T.ReadFrom #bytes read = %d != %T.WriteTo #bytes written = %d",
		input, bytesRead, input, bytesWritten)
	requireSerializersEqual(t, input, output, "ReadFrom")

	// encoding.BinaryUnmarshaler: decoding the MarshalBinary bytes must rebuild input.
	output = newEmptySerializer(input)
	require.NoError(t, output.UnmarshalBinary(data))
	requireSerializersEqual(t, input, output, "UnmarshalBinary")
}

// newEmptySerializer allocates a zero value of the type input points to and
// returns a pointer to it. input is expected to be a pointer whose type
// implements binarySerializer; otherwise reflect or the type assertion panics.
func newEmptySerializer(input binarySerializer) binarySerializer {
	return reflect.New(reflect.TypeOf(input).Elem()).Interface().(binarySerializer)
}

// requireSerializersEqual fails the test with a cmp.Diff report when the object
// rebuilt by method differs from the original one.
func requireSerializersEqual(t *testing.T, original, decoded binarySerializer, method string) {
	t.Helper()
	// cmp.Diff returns "" if and only if cmp.Equal reports the values as equal.
	if diff := cmp.Diff(original, decoded); diff != "" {
		t.Fatalf("invalid decoding: %T.%s output != input (-input +output):\n%s", original, method, diff)
	}
}

// EqualAsUint64 casts &T to an *uint64 and reports whether a and b are equal
// bit for bit. T must have exactly the size of a uint64 (e.g. uint64, int64,
// float64). Otherwise the function panics, because reading eight bytes from a
// smaller value would read past it and a larger value would only be partly compared.
func EqualAsUint64[T any](a, b T) bool {
	/* #nosec G103 -- unsafe.Sizeof only inspects the type size */
	if size := unsafe.Sizeof(a); size != unsafe.Sizeof(uint64(0)) {
		panic(fmt.Sprintf("buffer: cannot compare %T as uint64: size %d != 8", a, size))
	}
	/* #nosec G103 -- behavior and consequences well understood, pointer type cast */
	return *(*uint64)(unsafe.Pointer(&a)) == *(*uint64)(unsafe.Pointer(&b))
}

// EqualAsUint64Slice casts &[]T into *[]uint64 and performs a comparison.
// User must ensure that T has the same size as a uint64 (e.g. uint64, int64,
// float64): the slice header is reinterpreted as is, so the elements are
// compared bit for bit.
func EqualAsUint64Slice[T any](a, b []T) bool { /* #nosec G103 -- behavior and consequences well understood, pointer type cast */ aU64 := *(*[]uint64)(unsafe.Pointer(&a)) /* #nosec G103 -- behavior and consequences well understood, pointer type cast */ bU64 := *(*[]uint64)(unsafe.Pointer(&b)) if len(aU64) != len(bU64) { return false } if utils.Alias1D(aU64, bU64) { return true } for i := range aU64 { if aU64[i] != bU64[i] { return false } } return true } // EqualAsUint32Slice casts &[]T into *[]uint32 and performs a comparison. // User must ensure that T can be stored in an uint32. func EqualAsUint32Slice[T any](a, b []T) bool { /* #nosec G103 -- behavior and consequences well understood, pointer type cast */ aU32 := *(*[]uint32)(unsafe.Pointer(&a)) /* #nosec G103 -- behavior and consequences well understood, pointer type cast */ bU32 := *(*[]uint32)(unsafe.Pointer(&b)) if len(aU32) != len(bU32) { return false } if utils.Alias1D(aU32, bU32) { return true } for i := range aU32 { if aU32[i] != bU32[i] { return false } } return true } // EqualAsUint16Slice casts &[]T into *[]uint16 and performs a comparison. // User must ensure that T can be stored in an uint16. func EqualAsUint16Slice[T any](a, b []T) bool { /* #nosec G103 -- behavior and consequences well understood, pointer type cast */ aU16 := *(*[]uint16)(unsafe.Pointer(&a)) /* #nosec G103 -- behavior and consequences well understood, pointer type cast */ bU16 := *(*[]uint16)(unsafe.Pointer(&b)) if len(aU16) != len(bU16) { return false } if utils.Alias1D(aU16, bU16) { return true } for i := range aU16 { if aU16[i] != bU16[i] { return false } } return true } // EqualAsUint8Slice casts &[]T into *[]uint8 and performs a comparison. // User must ensure that T can be stored in an uint8. 
func EqualAsUint8Slice[T any](a, b []T) bool { /* #nosec G103 -- behavior and consequences well understood, pointer type cast */ aU8 := *(*[]uint8)(unsafe.Pointer(&a)) /* #nosec G103 -- behavior and consequences well understood, pointer type cast */ bU8 := *(*[]uint8)(unsafe.Pointer(&b)) if len(aU8) != len(bU8) { return false } if utils.Alias1D(aU8, bU8) { return true } for i := range aU8 { if aU8[i] != bU8[i] { return false } } return true }