Source file src/internal/zstd/zstd.go

     1  // Copyright 2023 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // Package zstd provides a decompressor for zstd streams,
     6  // described in RFC 8878. It does not support dictionaries.
     7  package zstd
     8  
     9  import (
    10  	"encoding/binary"
    11  	"errors"
    12  	"fmt"
    13  	"io"
    14  )
    15  
    16  // fuzzing is a fuzzer hook set to true when fuzzing.
    17  // This is used to reject cases where we don't match zstd.
    18  var fuzzing = false
    19  
    20  // Reader implements [io.Reader] to read a zstd compressed stream.
    21  type Reader struct {
    22  	// The underlying Reader.
    23  	r io.Reader
    24  
    25  	// Whether we have read the frame header.
    26  	// This is of interest when buffer is empty.
    27  	// If true we expect to see a new block.
    28  	sawFrameHeader bool
    29  
    30  	// Whether the current frame expects a checksum.
    31  	hasChecksum bool
    32  
    33  	// Whether we have read at least one frame.
    34  	readOneFrame bool
    35  
    36  	// True if the frame size is not known.
    37  	frameSizeUnknown bool
    38  
    39  	// The number of uncompressed bytes remaining in the current frame.
    40  	// If frameSizeUnknown is true, this is not valid.
    41  	remainingFrameSize uint64
    42  
    43  	// The number of bytes read from r up to the start of the current
    44  	// block, for error reporting.
    45  	blockOffset int64
    46  
    47  	// Buffered decompressed data.
    48  	buffer []byte
    49  	// Current read offset in buffer.
    50  	off int
    51  
    52  	// The current repeated offsets.
    53  	repeatedOffset1 uint32
    54  	repeatedOffset2 uint32
    55  	repeatedOffset3 uint32
    56  
    57  	// The current Huffman tree used for compressing literals.
    58  	huffmanTable     []uint16
    59  	huffmanTableBits int
    60  
    61  	// The window for back references.
    62  	window window
    63  
    64  	// A buffer available to hold a compressed block.
    65  	compressedBuf []byte
    66  
    67  	// A buffer for literals.
    68  	literals []byte
    69  
    70  	// Sequence decode FSE tables.
    71  	seqTables    [3][]fseBaselineEntry
    72  	seqTableBits [3]uint8
    73  
    74  	// Buffers for sequence decode FSE tables.
    75  	seqTableBuffers [3][]fseBaselineEntry
    76  
    77  	// Scratch space used for small reads, to avoid allocation.
    78  	scratch [16]byte
    79  
    80  	// A scratch table for reading an FSE. Only temporarily valid.
    81  	fseScratch []fseEntry
    82  
    83  	// For checksum computation.
    84  	checksum xxhash64
    85  }
    86  
    87  // NewReader creates a new Reader that decompresses data from the given reader.
    88  func NewReader(input io.Reader) *Reader {
    89  	r := new(Reader)
    90  	r.Reset(input)
    91  	return r
    92  }
    93  
    94  // Reset discards the current state and starts reading a new stream from r.
    95  // This permits reusing a Reader rather than allocating a new one.
    96  func (r *Reader) Reset(input io.Reader) {
    97  	r.r = input
    98  
    99  	// Several fields are preserved to avoid allocation.
   100  	// Others are always set before they are used.
   101  	r.sawFrameHeader = false
   102  	r.hasChecksum = false
   103  	r.readOneFrame = false
   104  	r.frameSizeUnknown = false
   105  	r.remainingFrameSize = 0
   106  	r.blockOffset = 0
   107  	r.buffer = r.buffer[:0]
   108  	r.off = 0
   109  	// repeatedOffset1
   110  	// repeatedOffset2
   111  	// repeatedOffset3
   112  	// huffmanTable
   113  	// huffmanTableBits
   114  	// window
   115  	// compressedBuf
   116  	// literals
   117  	// seqTables
   118  	// seqTableBits
   119  	// seqTableBuffers
   120  	// scratch
   121  	// fseScratch
   122  }
   123  
   124  // Read implements [io.Reader].
   125  func (r *Reader) Read(p []byte) (int, error) {
   126  	if err := r.refillIfNeeded(); err != nil {
   127  		return 0, err
   128  	}
   129  	n := copy(p, r.buffer[r.off:])
   130  	r.off += n
   131  	return n, nil
   132  }
   133  
   134  // ReadByte implements [io.ByteReader].
   135  func (r *Reader) ReadByte() (byte, error) {
   136  	if err := r.refillIfNeeded(); err != nil {
   137  		return 0, err
   138  	}
   139  	ret := r.buffer[r.off]
   140  	r.off++
   141  	return ret, nil
   142  }
   143  
   144  // refillIfNeeded reads the next block if necessary.
   145  func (r *Reader) refillIfNeeded() error {
   146  	for r.off >= len(r.buffer) {
   147  		if err := r.refill(); err != nil {
   148  			return err
   149  		}
   150  		r.off = 0
   151  	}
   152  	return nil
   153  }
   154  
   155  // refill reads and decompresses the next block.
   156  func (r *Reader) refill() error {
   157  	if !r.sawFrameHeader {
   158  		if err := r.readFrameHeader(); err != nil {
   159  			return err
   160  		}
   161  	}
   162  	return r.readBlock()
   163  }
   164  
   165  // readFrameHeader reads the frame header and prepares to read a block.
   166  func (r *Reader) readFrameHeader() error {
   167  retry:
   168  	relativeOffset := 0
   169  
   170  	// Read magic number. RFC 3.1.1.
   171  	if _, err := io.ReadFull(r.r, r.scratch[:4]); err != nil {
   172  		// We require that the stream contains at least one frame.
   173  		if err == io.EOF && !r.readOneFrame {
   174  			err = io.ErrUnexpectedEOF
   175  		}
   176  		return r.wrapError(relativeOffset, err)
   177  	}
   178  
   179  	if magic := binary.LittleEndian.Uint32(r.scratch[:4]); magic != 0xfd2fb528 {
   180  		if magic >= 0x184d2a50 && magic <= 0x184d2a5f {
   181  			// This is a skippable frame.
   182  			r.blockOffset += int64(relativeOffset) + 4
   183  			if err := r.skipFrame(); err != nil {
   184  				return err
   185  			}
   186  			r.readOneFrame = true
   187  			goto retry
   188  		}
   189  
   190  		return r.makeError(relativeOffset, "invalid magic number")
   191  	}
   192  
   193  	relativeOffset += 4
   194  
   195  	// Read Frame_Header_Descriptor. RFC 3.1.1.1.1.
   196  	if _, err := io.ReadFull(r.r, r.scratch[:1]); err != nil {
   197  		return r.wrapNonEOFError(relativeOffset, err)
   198  	}
   199  	descriptor := r.scratch[0]
   200  
   201  	singleSegment := descriptor&(1<<5) != 0
   202  
   203  	fcsFieldSize := 1 << (descriptor >> 6)
   204  	if fcsFieldSize == 1 && !singleSegment {
   205  		fcsFieldSize = 0
   206  	}
   207  
   208  	var windowDescriptorSize int
   209  	if singleSegment {
   210  		windowDescriptorSize = 0
   211  	} else {
   212  		windowDescriptorSize = 1
   213  	}
   214  
   215  	if descriptor&(1<<3) != 0 {
   216  		return r.makeError(relativeOffset, "reserved bit set in frame header descriptor")
   217  	}
   218  
   219  	r.hasChecksum = descriptor&(1<<2) != 0
   220  	if r.hasChecksum {
   221  		r.checksum.reset()
   222  	}
   223  
   224  	// Dictionary_ID_Flag. RFC 3.1.1.1.1.6.
   225  	dictionaryIdSize := 0
   226  	if dictIdFlag := descriptor & 3; dictIdFlag != 0 {
   227  		dictionaryIdSize = 1 << (dictIdFlag - 1)
   228  	}
   229  
   230  	relativeOffset++
   231  
   232  	headerSize := windowDescriptorSize + dictionaryIdSize + fcsFieldSize
   233  
   234  	if _, err := io.ReadFull(r.r, r.scratch[:headerSize]); err != nil {
   235  		return r.wrapNonEOFError(relativeOffset, err)
   236  	}
   237  
   238  	// Figure out the maximum amount of data we need to retain
   239  	// for backreferences.
   240  	var windowSize int
   241  	if !singleSegment {
   242  		// Window descriptor. RFC 3.1.1.1.2.
   243  		windowDescriptor := r.scratch[0]
   244  		exponent := uint64(windowDescriptor >> 3)
   245  		mantissa := uint64(windowDescriptor & 7)
   246  		windowLog := exponent + 10
   247  		windowBase := uint64(1) << windowLog
   248  		windowAdd := (windowBase / 8) * mantissa
   249  		windowSize = int(windowBase + windowAdd)
   250  
   251  		// Default zstd sets limits on the window size.
   252  		if fuzzing && (windowLog > 31 || windowSize > 1<<27) {
   253  			return r.makeError(relativeOffset, "windowSize too large")
   254  		}
   255  	}
   256  
   257  	// Dictionary_ID. RFC 3.1.1.1.3.
   258  	if dictionaryIdSize != 0 {
   259  		dictionaryId := r.scratch[windowDescriptorSize : windowDescriptorSize+dictionaryIdSize]
   260  		// Allow only zero Dictionary ID.
   261  		for _, b := range dictionaryId {
   262  			if b != 0 {
   263  				return r.makeError(relativeOffset, "dictionaries are not supported")
   264  			}
   265  		}
   266  	}
   267  
   268  	// Frame_Content_Size. RFC 3.1.1.1.4.
   269  	r.frameSizeUnknown = false
   270  	r.remainingFrameSize = 0
   271  	fb := r.scratch[windowDescriptorSize+dictionaryIdSize:]
   272  	switch fcsFieldSize {
   273  	case 0:
   274  		r.frameSizeUnknown = true
   275  	case 1:
   276  		r.remainingFrameSize = uint64(fb[0])
   277  	case 2:
   278  		r.remainingFrameSize = 256 + uint64(binary.LittleEndian.Uint16(fb))
   279  	case 4:
   280  		r.remainingFrameSize = uint64(binary.LittleEndian.Uint32(fb))
   281  	case 8:
   282  		r.remainingFrameSize = binary.LittleEndian.Uint64(fb)
   283  	default:
   284  		panic("unreachable")
   285  	}
   286  
   287  	// RFC 3.1.1.1.2.
   288  	// When Single_Segment_Flag is set, Window_Descriptor is not present.
   289  	// In this case, Window_Size is Frame_Content_Size.
   290  	if singleSegment {
   291  		windowSize = int(r.remainingFrameSize)
   292  	}
   293  
   294  	// RFC 8878 3.1.1.1.1.2. permits us to set an 8M max on window size.
   295  	if windowSize > 8<<20 {
   296  		windowSize = 8 << 20
   297  	}
   298  
   299  	relativeOffset += headerSize
   300  
   301  	r.sawFrameHeader = true
   302  	r.readOneFrame = true
   303  	r.blockOffset += int64(relativeOffset)
   304  
   305  	// Prepare to read blocks from the frame.
   306  	r.repeatedOffset1 = 1
   307  	r.repeatedOffset2 = 4
   308  	r.repeatedOffset3 = 8
   309  	r.huffmanTableBits = 0
   310  	r.window.reset(windowSize)
   311  	r.seqTables[0] = nil
   312  	r.seqTables[1] = nil
   313  	r.seqTables[2] = nil
   314  
   315  	return nil
   316  }
   317  
   318  // skipFrame skips a skippable frame. RFC 3.1.2.
   319  func (r *Reader) skipFrame() error {
   320  	relativeOffset := 0
   321  
   322  	if _, err := io.ReadFull(r.r, r.scratch[:4]); err != nil {
   323  		return r.wrapNonEOFError(relativeOffset, err)
   324  	}
   325  
   326  	relativeOffset += 4
   327  
   328  	size := binary.LittleEndian.Uint32(r.scratch[:4])
   329  	if size == 0 {
   330  		r.blockOffset += int64(relativeOffset)
   331  		return nil
   332  	}
   333  
   334  	if seeker, ok := r.r.(io.Seeker); ok {
   335  		r.blockOffset += int64(relativeOffset)
   336  		// Implementations of Seeker do not always detect invalid offsets,
   337  		// so check that the new offset is valid by comparing to the end.
   338  		prev, err := seeker.Seek(0, io.SeekCurrent)
   339  		if err != nil {
   340  			return r.wrapError(0, err)
   341  		}
   342  		end, err := seeker.Seek(0, io.SeekEnd)
   343  		if err != nil {
   344  			return r.wrapError(0, err)
   345  		}
   346  		if prev > end-int64(size) {
   347  			r.blockOffset += end - prev
   348  			return r.makeEOFError(0)
   349  		}
   350  
   351  		// The new offset is valid, so seek to it.
   352  		_, err = seeker.Seek(prev+int64(size), io.SeekStart)
   353  		if err != nil {
   354  			return r.wrapError(0, err)
   355  		}
   356  		r.blockOffset += int64(size)
   357  		return nil
   358  	}
   359  
   360  	var skip []byte
   361  	const chunk = 1 << 20 // 1M
   362  	for size >= chunk {
   363  		if len(skip) == 0 {
   364  			skip = make([]byte, chunk)
   365  		}
   366  		if _, err := io.ReadFull(r.r, skip); err != nil {
   367  			return r.wrapNonEOFError(relativeOffset, err)
   368  		}
   369  		relativeOffset += chunk
   370  		size -= chunk
   371  	}
   372  	if size > 0 {
   373  		if len(skip) == 0 {
   374  			skip = make([]byte, size)
   375  		}
   376  		if _, err := io.ReadFull(r.r, skip); err != nil {
   377  			return r.wrapNonEOFError(relativeOffset, err)
   378  		}
   379  		relativeOffset += int(size)
   380  	}
   381  
   382  	r.blockOffset += int64(relativeOffset)
   383  
   384  	return nil
   385  }
   386  
   387  // readBlock reads the next block from a frame.
   388  func (r *Reader) readBlock() error {
   389  	relativeOffset := 0
   390  
   391  	// Read Block_Header. RFC 3.1.1.2.
   392  	if _, err := io.ReadFull(r.r, r.scratch[:3]); err != nil {
   393  		return r.wrapNonEOFError(relativeOffset, err)
   394  	}
   395  
   396  	relativeOffset += 3
   397  
   398  	header := uint32(r.scratch[0]) | (uint32(r.scratch[1]) << 8) | (uint32(r.scratch[2]) << 16)
   399  
   400  	lastBlock := header&1 != 0
   401  	blockType := (header >> 1) & 3
   402  	blockSize := int(header >> 3)
   403  
   404  	// Maximum block size is smaller of window size and 128K.
   405  	// We don't record the window size for a single segment frame,
   406  	// so just use 128K. RFC 3.1.1.2.3, 3.1.1.2.4.
   407  	if blockSize > 128<<10 || (r.window.size > 0 && blockSize > r.window.size) {
   408  		return r.makeError(relativeOffset, "block size too large")
   409  	}
   410  
   411  	// Handle different block types. RFC 3.1.1.2.2.
   412  	switch blockType {
   413  	case 0:
   414  		r.setBufferSize(blockSize)
   415  		if _, err := io.ReadFull(r.r, r.buffer); err != nil {
   416  			return r.wrapNonEOFError(relativeOffset, err)
   417  		}
   418  		relativeOffset += blockSize
   419  		r.blockOffset += int64(relativeOffset)
   420  	case 1:
   421  		r.setBufferSize(blockSize)
   422  		if _, err := io.ReadFull(r.r, r.scratch[:1]); err != nil {
   423  			return r.wrapNonEOFError(relativeOffset, err)
   424  		}
   425  		relativeOffset++
   426  		v := r.scratch[0]
   427  		for i := range r.buffer {
   428  			r.buffer[i] = v
   429  		}
   430  		r.blockOffset += int64(relativeOffset)
   431  	case 2:
   432  		r.blockOffset += int64(relativeOffset)
   433  		if err := r.compressedBlock(blockSize); err != nil {
   434  			return err
   435  		}
   436  		r.blockOffset += int64(blockSize)
   437  	case 3:
   438  		return r.makeError(relativeOffset, "invalid block type")
   439  	}
   440  
   441  	if !r.frameSizeUnknown {
   442  		if uint64(len(r.buffer)) > r.remainingFrameSize {
   443  			return r.makeError(relativeOffset, "too many uncompressed bytes in frame")
   444  		}
   445  		r.remainingFrameSize -= uint64(len(r.buffer))
   446  	}
   447  
   448  	if r.hasChecksum {
   449  		r.checksum.update(r.buffer)
   450  	}
   451  
   452  	if !lastBlock {
   453  		r.window.save(r.buffer)
   454  	} else {
   455  		if !r.frameSizeUnknown && r.remainingFrameSize != 0 {
   456  			return r.makeError(relativeOffset, "not enough uncompressed bytes for frame")
   457  		}
   458  		// Check for checksum at end of frame. RFC 3.1.1.
   459  		if r.hasChecksum {
   460  			if _, err := io.ReadFull(r.r, r.scratch[:4]); err != nil {
   461  				return r.wrapNonEOFError(0, err)
   462  			}
   463  
   464  			inputChecksum := binary.LittleEndian.Uint32(r.scratch[:4])
   465  			dataChecksum := uint32(r.checksum.digest())
   466  			if inputChecksum != dataChecksum {
   467  				return r.wrapError(0, fmt.Errorf("invalid checksum: got %#x want %#x", dataChecksum, inputChecksum))
   468  			}
   469  
   470  			r.blockOffset += 4
   471  		}
   472  		r.sawFrameHeader = false
   473  	}
   474  
   475  	return nil
   476  }
   477  
   478  // setBufferSize sets the decompressed buffer size.
   479  // When this is called the buffer is empty.
   480  func (r *Reader) setBufferSize(size int) {
   481  	if cap(r.buffer) < size {
   482  		need := size - cap(r.buffer)
   483  		r.buffer = append(r.buffer[:cap(r.buffer)], make([]byte, need)...)
   484  	}
   485  	r.buffer = r.buffer[:size]
   486  }
   487  
   488  // zstdError is an error while decompressing.
   489  type zstdError struct {
   490  	offset int64
   491  	err    error
   492  }
   493  
   494  func (ze *zstdError) Error() string {
   495  	return fmt.Sprintf("zstd decompression error at %d: %v", ze.offset, ze.err)
   496  }
   497  
   498  func (ze *zstdError) Unwrap() error {
   499  	return ze.err
   500  }
   501  
   502  func (r *Reader) makeEOFError(off int) error {
   503  	return r.wrapError(off, io.ErrUnexpectedEOF)
   504  }
   505  
   506  func (r *Reader) wrapNonEOFError(off int, err error) error {
   507  	if err == io.EOF {
   508  		err = io.ErrUnexpectedEOF
   509  	}
   510  	return r.wrapError(off, err)
   511  }
   512  
   513  func (r *Reader) makeError(off int, msg string) error {
   514  	return r.wrapError(off, errors.New(msg))
   515  }
   516  
   517  func (r *Reader) wrapError(off int, err error) error {
   518  	if err == io.EOF {
   519  		return err
   520  	}
   521  	return &zstdError{r.blockOffset + int64(off), err}
   522  }
   523  

View as plain text