package structs

import (
	"bufio"
	"bytes"
	"encoding"
	"encoding/binary"
	"fmt"
	"io"

	"github.com/tuneinsight/lattigo/v4/utils/buffer"
)

// binarySerializer groups the standard-library interfaces for binary
// marshalling (encoding.BinaryMarshaler / encoding.BinaryUnmarshaler) and for
// stream serialization (io.WriterTo / io.ReaderFrom).
type binarySerializer interface {
	encoding.BinaryMarshaler
	encoding.BinaryUnmarshaler
	io.WriterTo
	io.ReaderFrom
	// Encoder
	// Decoder
}

// Vector is a generic slice whose copy and serialization methods delegate to
// the corresponding methods of *T. Each such method checks at runtime that *T
// implements the interface it relies on.
type Vector[T any] []T

// CopyNew creates a copy of the object by calling CopyNew on (a pointer to)
// each component and returns a pointer to the new vector.
// It panics if *T does not implement CopyNewer[T].
func (v Vector[T]) CopyNew() *Vector[T] {
	if _, isCopiable := any(new(T)).(CopyNewer[T]); !isCopiable {
		// On a failed assertion the asserted value is a nil interface (whose
		// %T is "<nil>"), so name the expected interface via a typed pointer.
		panic(fmt.Errorf("vector component of type %T does not comply to %T", new(T), new(CopyNewer[T])))
	}

	vcpy := Vector[T](make([]T, len(v)))
	for i, c := range v {
		vcpy[i] = *any(&c).(CopyNewer[T]).CopyNew()
	}

	return &vcpy
}

// WriteTo writes the object on an io.Writer: the vector length (written with
// buffer.WriteInt) followed by each component serialized by its own WriteTo.
// To ensure optimal efficiency and minimal allocations, the user is encouraged
// to provide a struct implementing the interface buffer.Writer, which defines
// a subset of the methods of bufio.Writer.
// If w is not compliant to the buffer.Writer interface, it will be wrapped in
// a new bufio.Writer.
// For additional information, see lattigo/utils/buffer/writer.go.
// It returns an error if *T does not implement io.WriterTo.
func (v *Vector[T]) WriteTo(w io.Writer) (n int64, err error) {
	if _, isWritable := any(new(T)).(io.WriterTo); !isWritable {
		return 0, fmt.Errorf("vector component of type %T does not comply to %T", new(T), new(io.WriterTo))
	}

	switch w := w.(type) {
	case buffer.Writer:
		vval := *v

		var inc int
		if inc, err = buffer.WriteInt(w, len(vval)); err != nil {
			return int64(inc), err
		}
		n += int64(inc)

		for _, c := range vval {
			var m int64
			m, err = any(&c).(io.WriterTo).WriteTo(w)
			n += m
			if err != nil {
				return n, err
			}
		}

		// Flush so the buffered bytes reach the underlying writer.
		return n, w.Flush()
	default:
		// bufio.Writer satisfies buffer.Writer, so the recursion takes the
		// branch above.
		return v.WriteTo(bufio.NewWriter(w))
	}
}

// ReadFrom reads the object from an io.Reader.
// To ensure optimal efficiency and minimal allocations, the user is encouraged
// to provide a struct implementing the interface buffer.Reader, which defines
// a subset of the methods of bufio.Reader.
// If r is not compliant to the buffer.Reader interface, it will be wrapped in
// a new bufio.Reader.
// For additional information, see lattigo/utils/buffer/reader.go. func (v *Vector[T]) ReadFrom(r io.Reader) (n int64, err error) { if r, isReadable := any(new(T)).(io.ReaderFrom); !isReadable { return 0, fmt.Errorf("vector component of type %T does not comply to %T", new(T), r) } // TODO: when has access to Reader's buffer, call Decode ? switch r := r.(type) { case buffer.Reader: var size int var inc int // TODO int64 in buffer package ? if inc, err = buffer.ReadInt(r, &size); err != nil { return int64(inc), fmt.Errorf("cannot read vector size: %w", err) } n += int64(inc) if cap(*v) < size { *v = make([]T, size) } *v = (*v)[:size] for i := range *v { inc, err := any(&(*v)[i]).(io.ReaderFrom).ReadFrom(r) n += inc if err != nil { return n, err } } return int64(n), nil default: return v.ReadFrom(bufio.NewReader(r)) } } // BinarySize returns the size in bytes of the object // when encoded using Encode. func (v Vector[T]) BinarySize() (size int) { if s, isSizable := any(new(T)).(BinarySizer); !isSizable { panic(fmt.Errorf("vector component of type %T does not comply to %T", new(T), s)) } size += 8 for _, c := range v { size += any(&c).(BinarySizer).BinarySize() } return } // Encode encodes the object into a binary form on a preallocated slice of bytes // and returns the number of bytes written. func (v *Vector[T]) Encode(b []byte) (n int, err error) { if e, isEncodable := any(new(T)).(Encoder); !isEncodable { panic(fmt.Errorf("vector component of type %T does not comply to %T", new(T), e)) } vval := *v binary.LittleEndian.PutUint64(b[n:], uint64(len(vval))) n += 8 var inc int for _, c := range vval { if inc, err := any(&c).(Encoder).Encode(b[n:]); err != nil { return n + inc, err } n += inc } return } // Decode decodes a slice of bytes generated by Encode // on the object and returns the number of bytes read. 
func (v *Vector[T]) Decode(p []byte) (n int, err error) { if d, isDecodable := any(new(T)).(Decoder); !isDecodable { panic(fmt.Errorf("vector component of type %T does not comply to %T", new(T), d)) } size := int(binary.LittleEndian.Uint64(p[n:])) // TODO: there is a bug here but it is not caught by the tests. n += 8 if cap(*v) < size { *v = make([]T, size) } *v = (*v)[:size] var inc int for i := range *v { if inc, err = any(&(*v)[i]).(Decoder).Decode(p[n:]); err != nil { return n + inc, err } n += inc } return } // MarshalBinary encodes the object into a binary form on a newly allocated slice of bytes. func (v *Vector[T]) MarshalBinary() (p []byte, err error) { buf := bytes.NewBuffer([]byte{}) _, err = v.WriteTo(buf) return buf.Bytes(), err } // UnmarshalBinary decodes a slice of bytes generated by // MarshalBinary or WriteTo on the object. func (v *Vector[T]) UnmarshalBinary(p []byte) (err error) { _, err = v.ReadFrom(bytes.NewBuffer(p)) return }