Source file src/internal/zstd/huff.go

     1  // Copyright 2023 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package zstd
     6  
     7  import (
     8  	"io"
     9  	"math/bits"
    10  )
    11  
    12  // maxHuffmanBits is the largest possible Huffman table bits.
    13  const maxHuffmanBits = 11
    14  
    15  // readHuff reads Huffman table from data starting at off into table.
    16  // Each entry in a Huffman table is a pair of bytes.
    17  // The high byte is the encoded value. The low byte is the number
    18  // of bits used to encode that value. We index into the table
    19  // with a value of size tableBits. A value that requires fewer bits
    20  // appear in the table multiple times.
    21  // This returns the number of bits in the Huffman table and the new offset.
    22  // RFC 4.2.1.
    23  func (r *Reader) readHuff(data block, off int, table []uint16) (tableBits, roff int, err error) {
    24  	if off >= len(data) {
    25  		return 0, 0, r.makeEOFError(off)
    26  	}
    27  
    28  	hdr := data[off]
    29  	off++
    30  
    31  	var weights [256]uint8
    32  	var count int
    33  	if hdr < 128 {
    34  		// The table is compressed using an FSE. RFC 4.2.1.2.
    35  		if len(r.fseScratch) < 1<<6 {
    36  			r.fseScratch = make([]fseEntry, 1<<6)
    37  		}
    38  		fseBits, noff, err := r.readFSE(data, off, 255, 6, r.fseScratch)
    39  		if err != nil {
    40  			return 0, 0, err
    41  		}
    42  		fseTable := r.fseScratch
    43  
    44  		if off+int(hdr) > len(data) {
    45  			return 0, 0, r.makeEOFError(off)
    46  		}
    47  
    48  		rbr, err := r.makeReverseBitReader(data, off+int(hdr)-1, noff)
    49  		if err != nil {
    50  			return 0, 0, err
    51  		}
    52  
    53  		state1, err := rbr.val(uint8(fseBits))
    54  		if err != nil {
    55  			return 0, 0, err
    56  		}
    57  
    58  		state2, err := rbr.val(uint8(fseBits))
    59  		if err != nil {
    60  			return 0, 0, err
    61  		}
    62  
    63  		// There are two independent FSE streams, tracked by
    64  		// state1 and state2. We decode them alternately.
    65  
    66  		for {
    67  			pt := &fseTable[state1]
    68  			if !rbr.fetch(pt.bits) {
    69  				if count >= 254 {
    70  					return 0, 0, rbr.makeError("Huffman count overflow")
    71  				}
    72  				weights[count] = pt.sym
    73  				weights[count+1] = fseTable[state2].sym
    74  				count += 2
    75  				break
    76  			}
    77  
    78  			v, err := rbr.val(pt.bits)
    79  			if err != nil {
    80  				return 0, 0, err
    81  			}
    82  			state1 = uint32(pt.base) + v
    83  
    84  			if count >= 255 {
    85  				return 0, 0, rbr.makeError("Huffman count overflow")
    86  			}
    87  
    88  			weights[count] = pt.sym
    89  			count++
    90  
    91  			pt = &fseTable[state2]
    92  
    93  			if !rbr.fetch(pt.bits) {
    94  				if count >= 254 {
    95  					return 0, 0, rbr.makeError("Huffman count overflow")
    96  				}
    97  				weights[count] = pt.sym
    98  				weights[count+1] = fseTable[state1].sym
    99  				count += 2
   100  				break
   101  			}
   102  
   103  			v, err = rbr.val(pt.bits)
   104  			if err != nil {
   105  				return 0, 0, err
   106  			}
   107  			state2 = uint32(pt.base) + v
   108  
   109  			if count >= 255 {
   110  				return 0, 0, rbr.makeError("Huffman count overflow")
   111  			}
   112  
   113  			weights[count] = pt.sym
   114  			count++
   115  		}
   116  
   117  		off += int(hdr)
   118  	} else {
   119  		// The table is not compressed. Each weight is 4 bits.
   120  
   121  		count = int(hdr) - 127
   122  		if off+((count+1)/2) >= len(data) {
   123  			return 0, 0, io.ErrUnexpectedEOF
   124  		}
   125  		for i := 0; i < count; i += 2 {
   126  			b := data[off]
   127  			off++
   128  			weights[i] = b >> 4
   129  			weights[i+1] = b & 0xf
   130  		}
   131  	}
   132  
   133  	// RFC 4.2.1.3.
   134  
   135  	var weightMark [13]uint32
   136  	weightMask := uint32(0)
   137  	for _, w := range weights[:count] {
   138  		if w > 12 {
   139  			return 0, 0, r.makeError(off, "Huffman weight overflow")
   140  		}
   141  		weightMark[w]++
   142  		if w > 0 {
   143  			weightMask += 1 << (w - 1)
   144  		}
   145  	}
   146  	if weightMask == 0 {
   147  		return 0, 0, r.makeError(off, "bad Huffman weights")
   148  	}
   149  
   150  	tableBits = 32 - bits.LeadingZeros32(weightMask)
   151  	if tableBits > maxHuffmanBits {
   152  		return 0, 0, r.makeError(off, "bad Huffman weights")
   153  	}
   154  
   155  	if len(table) < 1<<tableBits {
   156  		return 0, 0, r.makeError(off, "Huffman table too small")
   157  	}
   158  
   159  	// Work out the last weight value, which is omitted because
   160  	// the weights must sum to a power of two.
   161  	left := (uint32(1) << tableBits) - weightMask
   162  	if left == 0 {
   163  		return 0, 0, r.makeError(off, "bad Huffman weights")
   164  	}
   165  	highBit := 31 - bits.LeadingZeros32(left)
   166  	if uint32(1)<<highBit != left {
   167  		return 0, 0, r.makeError(off, "bad Huffman weights")
   168  	}
   169  	if count >= 256 {
   170  		return 0, 0, r.makeError(off, "Huffman weight overflow")
   171  	}
   172  	weights[count] = uint8(highBit + 1)
   173  	count++
   174  	weightMark[highBit+1]++
   175  
   176  	if weightMark[1] < 2 || weightMark[1]&1 != 0 {
   177  		return 0, 0, r.makeError(off, "bad Huffman weights")
   178  	}
   179  
   180  	// Change weightMark from a count of weights to the index of
   181  	// the first symbol for that weight. We shift the indexes to
   182  	// also store how many we have seen so far,
   183  	next := uint32(0)
   184  	for i := 0; i < tableBits; i++ {
   185  		cur := next
   186  		next += weightMark[i+1] << i
   187  		weightMark[i+1] = cur
   188  	}
   189  
   190  	for i, w := range weights[:count] {
   191  		if w == 0 {
   192  			continue
   193  		}
   194  		length := uint32(1) << (w - 1)
   195  		tval := uint16(i)<<8 | (uint16(tableBits) + 1 - uint16(w))
   196  		start := weightMark[w]
   197  		for j := uint32(0); j < length; j++ {
   198  			table[start+j] = tval
   199  		}
   200  		weightMark[w] += length
   201  	}
   202  
   203  	return tableBits, off, nil
   204  }
   205  

View as plain text