diff --git a/counter_sample.go b/counter_sample.go index 31df19b..2b5084f 100644 --- a/counter_sample.go +++ b/counter_sample.go @@ -48,7 +48,7 @@ func (s *CounterSample) GetRecords() []Record { return s.Records } -func decodeCounterSample(r io.ReadSeeker) (Sample, error) { +func decodeCounterSample(r io.ReadSeeker, format uint32) (Sample, error) { s := &CounterSample{} var err error @@ -58,24 +58,43 @@ func decodeCounterSample(r io.ReadSeeker) (Sample, error) { return nil, err } - err = binary.Read(r, binary.BigEndian, &s.SourceIdType) - if err != nil { - return nil, err - } + switch format { + case TypeCounterSample: + err = binary.Read(r, binary.BigEndian, &s.SourceIdType) + if err != nil { + return nil, err + } - var srcIdIndexVal [3]byte - n, err := r.Read(srcIdIndexVal[:]) - if err != nil { - return nil, err - } + var srcIdIndexVal [3]byte + n, err := r.Read(srcIdIndexVal[:]) + if err != nil { + return nil, err + } - if n != 3 { - return nil, errors.New("sflow: counter sample decoding error") - } + if n != 3 { + return nil, errors.New("sflow: counter sample decoding error") + } + + s.SourceIdIndexVal = uint32(srcIdIndexVal[2]) | + uint32(srcIdIndexVal[1])<<8 | + uint32(srcIdIndexVal[0])<<16 + + case TypeExpandedCounterSample: + var sourceIdType uint32 + err = binary.Read(r, binary.BigEndian, &sourceIdType) + if err != nil { + return nil, err + } + s.SourceIdType = byte(sourceIdType) + + err = binary.Read(r, binary.BigEndian, &s.SourceIdIndexVal) + if err != nil { + return nil, err + } - s.SourceIdIndexVal = uint32(srcIdIndexVal[2]) | - uint32(srcIdIndexVal[1])<<8 | - uint32(srcIdIndexVal[0])<<16 + default: + return nil, ErrUnknownSampleType + } err = binary.Read(r, binary.BigEndian, &s.numRecords) if err != nil { diff --git a/counter_sample_encode_test.go b/counter_sample_encode_test.go index 6befa51..5d42bed 100644 --- a/counter_sample_encode_test.go +++ b/counter_sample_encode_test.go @@ -48,7 +48,7 @@ func TestDecodeEncodeAndDecodeCounterSample(t *testing.T) { buf.Read(skip[:]) // bytes.Buffer is not an io.ReadSeeker. bytes.Reader is. - decodedSample, err := decodeCounterSample(bytes.NewReader(buf.Bytes())) + decodedSample, err := decodeCounterSample(bytes.NewReader(buf.Bytes()), TypeCounterSample) if err != nil { t.Fatal(err) } diff --git a/sample.go b/sample.go index 140b788..d425fb0 100644 --- a/sample.go +++ b/sample.go @@ -39,7 +39,10 @@ func decodeSample(r io.ReadSeeker) (Sample, error) { switch format { case TypeCounterSample: - return decodeCounterSample(r) + return decodeCounterSample(r, format) + + case TypeExpandedCounterSample: + return decodeCounterSample(r, format) case TypeFlowSample: return decodeFlowSample(r)