forked from Shiloh/githaven
401 lines
12 KiB
Go
401 lines
12 KiB
Go
|
// Copyright 2015, Joe Tsai. All rights reserved.
|
||
|
// Use of this source code is governed by a BSD-style
|
||
|
// license that can be found in the LICENSE.md file.
|
||
|
|
||
|
// Package prefix implements bit readers and writers that use prefix encoding.
|
||
|
package prefix
|
||
|
|
||
|
import (
|
||
|
"fmt"
|
||
|
"sort"
|
||
|
|
||
|
"github.com/dsnet/compress/internal"
|
||
|
"github.com/dsnet/compress/internal/errors"
|
||
|
)
|
||
|
|
||
|
func errorf(c int, f string, a ...interface{}) error {
|
||
|
return errors.Error{Code: c, Pkg: "prefix", Msg: fmt.Sprintf(f, a...)}
|
||
|
}
|
||
|
|
||
|
func panicf(c int, f string, a ...interface{}) {
|
||
|
errors.Panic(errorf(c, f, a...))
|
||
|
}
|
||
|
|
||
|
// Bit-width constants for prefix codes. countBits and valueBits add up to 32.
const (
	countBits = 5  // width of the field holding a code's bit-length
	valueBits = 27 // width of the field holding a code's value

	countMask = 1<<countBits - 1 // mask selecting the low countBits bits (0x1f)
)
|
||
|
|
||
|
// PrefixCode describes one entry of a prefix code: conceptually, a mapping
// from an arbitrary symbol to a bit-string.
//
// Callers normally fill in Sym and Cnt; GenerateLengths fills in Len and
// GeneratePrefixes fills in Val.
type PrefixCode struct {
	// Supplied by the user.
	Sym uint32 // symbol being mapped
	Cnt uint32 // number of times the symbol is used

	// Generated by this package.
	Len uint32 // bit-length of the prefix code
	Val uint32 // value of the prefix code; must be in 0..(1<<Len)-1
}

// PrefixCodes is a list of PrefixCode entries that together form one prefix code.
type PrefixCodes []PrefixCode
|
||
|
|
||
|
type prefixCodesBySymbol []PrefixCode
|
||
|
|
||
|
func (c prefixCodesBySymbol) Len() int { return len(c) }
|
||
|
func (c prefixCodesBySymbol) Less(i, j int) bool { return c[i].Sym < c[j].Sym }
|
||
|
func (c prefixCodesBySymbol) Swap(i, j int) { c[i], c[j] = c[j], c[i] }
|
||
|
|
||
|
type prefixCodesByCount []PrefixCode
|
||
|
|
||
|
func (c prefixCodesByCount) Len() int { return len(c) }
|
||
|
func (c prefixCodesByCount) Less(i, j int) bool {
|
||
|
return c[i].Cnt < c[j].Cnt || (c[i].Cnt == c[j].Cnt && c[i].Sym < c[j].Sym)
|
||
|
}
|
||
|
func (c prefixCodesByCount) Swap(i, j int) { c[i], c[j] = c[j], c[i] }
|
||
|
|
||
|
func (pc PrefixCodes) SortBySymbol() { sort.Sort(prefixCodesBySymbol(pc)) }
|
||
|
func (pc PrefixCodes) SortByCount() { sort.Sort(prefixCodesByCount(pc)) }
|
||
|
|
||
|
// Length computes the total bit-length using the Len and Cnt fields.
|
||
|
func (pc PrefixCodes) Length() (nb uint) {
|
||
|
for _, c := range pc {
|
||
|
nb += uint(c.Len * c.Cnt)
|
||
|
}
|
||
|
return nb
|
||
|
}
|
||
|
|
||
|
// checkLengths reports whether the codes form a complete prefix tree.
|
||
|
func (pc PrefixCodes) checkLengths() bool {
|
||
|
sum := 1 << valueBits
|
||
|
for _, c := range pc {
|
||
|
sum -= (1 << valueBits) >> uint(c.Len)
|
||
|
}
|
||
|
return sum == 0 || len(pc) == 0
|
||
|
}
|
||
|
|
||
|
// checkPrefixes reports whether all codes have non-overlapping prefixes.
|
||
|
func (pc PrefixCodes) checkPrefixes() bool {
|
||
|
for i, c1 := range pc {
|
||
|
for j, c2 := range pc {
|
||
|
mask := uint32(1)<<c1.Len - 1
|
||
|
if i != j && c1.Len <= c2.Len && c1.Val&mask == c2.Val&mask {
|
||
|
return false
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
return true
|
||
|
}
|
||
|
|
||
|
// checkCanonical reports whether all codes are canonical.
|
||
|
// That is, they have the following properties:
|
||
|
//
|
||
|
// 1. All codes of a given bit-length are consecutive values.
|
||
|
// 2. Shorter codes lexicographically precede longer codes.
|
||
|
//
|
||
|
// The codes must have unique symbols and be sorted by the symbol
|
||
|
// The Len and Val fields in each code must be populated.
|
||
|
func (pc PrefixCodes) checkCanonical() bool {
|
||
|
// Rule 1.
|
||
|
var vals [valueBits + 1]PrefixCode
|
||
|
for _, c := range pc {
|
||
|
if c.Len > 0 {
|
||
|
c.Val = internal.ReverseUint32N(c.Val, uint(c.Len))
|
||
|
if vals[c.Len].Cnt > 0 && vals[c.Len].Val+1 != c.Val {
|
||
|
return false
|
||
|
}
|
||
|
vals[c.Len].Val = c.Val
|
||
|
vals[c.Len].Cnt++
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// Rule 2.
|
||
|
var last PrefixCode
|
||
|
for _, v := range vals {
|
||
|
if v.Cnt > 0 {
|
||
|
curVal := v.Val - v.Cnt + 1
|
||
|
if last.Cnt != 0 && last.Val >= curVal {
|
||
|
return false
|
||
|
}
|
||
|
last = v
|
||
|
}
|
||
|
}
|
||
|
return true
|
||
|
}
|
||
|
|
||
|
// GenerateLengths assigns non-zero bit-lengths to all codes. Codes with high
|
||
|
// frequency counts will be assigned shorter codes to reduce bit entropy.
|
||
|
// This function is used primarily by compressors.
|
||
|
//
|
||
|
// The input codes must have the Cnt field populated, be sorted by count.
|
||
|
// Even if a code has a count of 0, a non-zero bit-length will be assigned.
|
||
|
//
|
||
|
// The result will have the Len field populated. The algorithm used guarantees
|
||
|
// that Len <= maxBits and that it is a complete prefix tree. The resulting
|
||
|
// codes will remain sorted by count.
|
||
|
func GenerateLengths(codes PrefixCodes, maxBits uint) error {
|
||
|
if len(codes) <= 1 {
|
||
|
if len(codes) == 1 {
|
||
|
codes[0].Len = 0
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// Verify that the codes are in ascending order by count.
|
||
|
cntLast := codes[0].Cnt
|
||
|
for _, c := range codes[1:] {
|
||
|
if c.Cnt < cntLast {
|
||
|
return errorf(errors.Invalid, "non-monotonically increasing symbol counts")
|
||
|
}
|
||
|
cntLast = c.Cnt
|
||
|
}
|
||
|
|
||
|
// Construct a Huffman tree used to generate the bit-lengths.
|
||
|
//
|
||
|
// The Huffman tree is a binary tree where each symbol lies as a leaf node
|
||
|
// on this tree. The length of the prefix code to assign is the depth of
|
||
|
// that leaf from the root. The Huffman algorithm, which runs in O(n),
|
||
|
// is used to generate the tree. It assumes that codes are sorted in
|
||
|
// increasing order of frequency.
|
||
|
//
|
||
|
// The algorithm is as follows:
|
||
|
// 1. Start with two queues, F and Q, where F contains all of the starting
|
||
|
// symbols sorted such that symbols with lowest counts come first.
|
||
|
// 2. While len(F)+len(Q) > 1:
|
||
|
// 2a. Dequeue the node from F or Q that has the lowest weight as N0.
|
||
|
// 2b. Dequeue the node from F or Q that has the lowest weight as N1.
|
||
|
// 2c. Create a new node N that has N0 and N1 as its children.
|
||
|
// 2d. Enqueue N into the back of Q.
|
||
|
// 3. The tree's root node is Q[0].
|
||
|
type node struct {
|
||
|
cnt uint32
|
||
|
|
||
|
// n0 or c0 represent the left child of this node.
|
||
|
// Since Go does not have unions, only one of these will be set.
|
||
|
// Similarly, n1 or c1 represent the right child of this node.
|
||
|
//
|
||
|
// If n0 or n1 is set, then it represents a "pointer" to another
|
||
|
// node in the Huffman tree. Since Go's pointer analysis cannot reason
|
||
|
// that these node pointers do not escape (golang.org/issue/13493),
|
||
|
// we use an index to a node in the nodes slice as a pseudo-pointer.
|
||
|
//
|
||
|
// If c0 or c1 is set, then it represents a leaf "node" in the
|
||
|
// Huffman tree. The leaves are the PrefixCode values themselves.
|
||
|
n0, n1 int // Index to child nodes
|
||
|
c0, c1 *PrefixCode
|
||
|
}
|
||
|
var nodeIdx int
|
||
|
var nodeArr [1024]node // Large enough to handle most cases on the stack
|
||
|
nodes := nodeArr[:]
|
||
|
if len(nodes) < len(codes) {
|
||
|
nodes = make([]node, len(codes)) // Number of internal nodes < number of leaves
|
||
|
}
|
||
|
freqs, queue := codes, nodes[:0]
|
||
|
for len(freqs)+len(queue) > 1 {
|
||
|
// These are the two smallest nodes at the front of freqs and queue.
|
||
|
var n node
|
||
|
if len(queue) == 0 || (len(freqs) > 0 && freqs[0].Cnt <= queue[0].cnt) {
|
||
|
n.c0, freqs = &freqs[0], freqs[1:]
|
||
|
n.cnt += n.c0.Cnt
|
||
|
} else {
|
||
|
n.cnt += queue[0].cnt
|
||
|
n.n0 = nodeIdx // nodeIdx is same as &queue[0] - &nodes[0]
|
||
|
nodeIdx++
|
||
|
queue = queue[1:]
|
||
|
}
|
||
|
if len(queue) == 0 || (len(freqs) > 0 && freqs[0].Cnt <= queue[0].cnt) {
|
||
|
n.c1, freqs = &freqs[0], freqs[1:]
|
||
|
n.cnt += n.c1.Cnt
|
||
|
} else {
|
||
|
n.cnt += queue[0].cnt
|
||
|
n.n1 = nodeIdx // nodeIdx is same as &queue[0] - &nodes[0]
|
||
|
nodeIdx++
|
||
|
queue = queue[1:]
|
||
|
}
|
||
|
queue = append(queue, n)
|
||
|
}
|
||
|
rootIdx := nodeIdx
|
||
|
|
||
|
// Search the whole binary tree, noting when we hit each leaf node.
|
||
|
// We do not care about the exact Huffman tree structure, but rather we only
|
||
|
// care about depth of each of the leaf nodes. That is, the depth determines
|
||
|
// how long each symbol is in bits.
|
||
|
//
|
||
|
// Since the number of leaves is n, there is at most n internal nodes.
|
||
|
// Thus, this algorithm runs in O(n).
|
||
|
var fixBits bool
|
||
|
var explore func(int, uint)
|
||
|
explore = func(rootIdx int, level uint) {
|
||
|
root := &nodes[rootIdx]
|
||
|
|
||
|
// Explore left branch.
|
||
|
if root.c0 == nil {
|
||
|
explore(root.n0, level+1)
|
||
|
} else {
|
||
|
fixBits = fixBits || (level > maxBits)
|
||
|
root.c0.Len = uint32(level)
|
||
|
}
|
||
|
|
||
|
// Explore right branch.
|
||
|
if root.c1 == nil {
|
||
|
explore(root.n1, level+1)
|
||
|
} else {
|
||
|
fixBits = fixBits || (level > maxBits)
|
||
|
root.c1.Len = uint32(level)
|
||
|
}
|
||
|
}
|
||
|
explore(rootIdx, 1)
|
||
|
|
||
|
// Fix the bit-lengths if we violate the maxBits requirement.
|
||
|
if fixBits {
|
||
|
// Create histogram for number of symbols with each bit-length.
|
||
|
var symBitsArr [valueBits + 1]uint32
|
||
|
symBits := symBitsArr[:] // symBits[nb] indicates number of symbols using nb bits
|
||
|
for _, c := range codes {
|
||
|
for int(c.Len) >= len(symBits) {
|
||
|
symBits = append(symBits, 0)
|
||
|
}
|
||
|
symBits[c.Len]++
|
||
|
}
|
||
|
|
||
|
// Fudge the tree such that the largest bit-length is <= maxBits.
|
||
|
// This is accomplish by effectively doing a tree rotation. That is, we
|
||
|
// increase the bit-length of some higher frequency code, so that the
|
||
|
// bit-lengths of lower frequency codes can be decreased.
|
||
|
//
|
||
|
// Visually, this looks like the following transform:
|
||
|
//
|
||
|
// Level Before After
|
||
|
// __ ___
|
||
|
// / \ / \
|
||
|
// n-1 X / \ /\ /\
|
||
|
// n X /\ X X X X
|
||
|
// n+1 X X
|
||
|
//
|
||
|
var treeRotate func(uint)
|
||
|
treeRotate = func(nb uint) {
|
||
|
if symBits[nb-1] == 0 {
|
||
|
treeRotate(nb - 1)
|
||
|
}
|
||
|
symBits[nb-1] -= 1 // Push this node to the level below
|
||
|
symBits[nb] += 3 // This level gets one node from above, two from below
|
||
|
symBits[nb+1] -= 2 // Push two nodes to the level above
|
||
|
}
|
||
|
for i := uint(len(symBits)) - 1; i > maxBits; i-- {
|
||
|
for symBits[i] > 0 {
|
||
|
treeRotate(i - 1)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// Assign bit-lengths to each code. Since codes is sorted in increasing
|
||
|
// order of frequency, that means that the most frequently used symbols
|
||
|
// should have the shortest bit-lengths. Thus, we copy symbols to codes
|
||
|
// from the back of codes first.
|
||
|
cs := codes
|
||
|
for nb, cnt := range symBits {
|
||
|
if cnt > 0 {
|
||
|
pos := len(cs) - int(cnt)
|
||
|
cs2 := cs[pos:]
|
||
|
for i := range cs2 {
|
||
|
cs2[i].Len = uint32(nb)
|
||
|
}
|
||
|
cs = cs[:pos]
|
||
|
}
|
||
|
}
|
||
|
if len(cs) != 0 {
|
||
|
panic("not all codes were used up")
|
||
|
}
|
||
|
}
|
||
|
|
||
|
if internal.Debug && !codes.checkLengths() {
|
||
|
panic("incomplete prefix tree detected")
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// GeneratePrefixes assigns a prefix value to all codes according to the
|
||
|
// bit-lengths. This function is used by both compressors and decompressors.
|
||
|
//
|
||
|
// The input codes must have the Sym and Len fields populated and be
|
||
|
// sorted by symbol. The bit-lengths of each code must be properly allocated,
|
||
|
// such that it forms a complete tree.
|
||
|
//
|
||
|
// The result will have the Val field populated and will produce a canonical
|
||
|
// prefix tree. The resulting codes will remain sorted by symbol.
|
||
|
func GeneratePrefixes(codes PrefixCodes) error {
|
||
|
if len(codes) <= 1 {
|
||
|
if len(codes) == 1 {
|
||
|
if codes[0].Len != 0 {
|
||
|
return errorf(errors.Invalid, "degenerate prefix tree with one node")
|
||
|
}
|
||
|
codes[0].Val = 0
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// Compute basic statistics on the symbols.
|
||
|
var bitCnts [valueBits + 1]uint
|
||
|
c0 := codes[0]
|
||
|
bitCnts[c0.Len]++
|
||
|
minBits, maxBits, symLast := c0.Len, c0.Len, c0.Sym
|
||
|
for _, c := range codes[1:] {
|
||
|
if c.Sym <= symLast {
|
||
|
return errorf(errors.Invalid, "non-unique or non-monotonically increasing symbols")
|
||
|
}
|
||
|
if minBits > c.Len {
|
||
|
minBits = c.Len
|
||
|
}
|
||
|
if maxBits < c.Len {
|
||
|
maxBits = c.Len
|
||
|
}
|
||
|
bitCnts[c.Len]++ // Histogram of bit counts
|
||
|
symLast = c.Sym // Keep track of last symbol
|
||
|
}
|
||
|
if minBits == 0 {
|
||
|
return errorf(errors.Invalid, "invalid prefix bit-length")
|
||
|
}
|
||
|
|
||
|
// Compute the next code for a symbol of a given bit length.
|
||
|
var nextCodes [valueBits + 1]uint
|
||
|
var code uint
|
||
|
for i := minBits; i <= maxBits; i++ {
|
||
|
code <<= 1
|
||
|
nextCodes[i] = code
|
||
|
code += bitCnts[i]
|
||
|
}
|
||
|
if code != 1<<maxBits {
|
||
|
return errorf(errors.Invalid, "degenerate prefix tree")
|
||
|
}
|
||
|
|
||
|
// Assign the code to each symbol.
|
||
|
for i, c := range codes {
|
||
|
codes[i].Val = internal.ReverseUint32N(uint32(nextCodes[c.Len]), uint(c.Len))
|
||
|
nextCodes[c.Len]++
|
||
|
}
|
||
|
|
||
|
if internal.Debug && !codes.checkPrefixes() {
|
||
|
panic("overlapping prefixes detected")
|
||
|
}
|
||
|
if internal.Debug && !codes.checkCanonical() {
|
||
|
panic("non-canonical prefixes detected")
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// allocUint32s returns a slice of length n. It reslices s when s has enough
// capacity (contents are not cleared); otherwise it allocates a fresh, zeroed
// slice with 50% extra capacity.
func allocUint32s(s []uint32, n int) []uint32 {
	if cap(s) < n {
		return make([]uint32, n, n*3/2)
	}
	return s[:n]
}
|
||
|
|
||
|
// extendSliceUint32s returns a slice of length n. It reslices s when s has
// enough capacity. Otherwise it allocates a new slice with 50% extra capacity.
// The new slice keeps every inner slice from s's full capacity, not just its
// length, so previously allocated inner buffers remain available for reuse.
func extendSliceUint32s(s [][]uint32, n int) [][]uint32 {
	if cap(s) < n {
		grown := make([][]uint32, n, n*3/2)
		copy(grown, s[:cap(s)])
		return grown
	}
	return s[:n]
}
|