source: trunk/Mars/mcore/huffman.h@ 14869

Last change on this file since 14869 was 14795, checked in by tbretz, 12 years ago
File size: 10.0 KB
Line 
1#ifndef FACT_huffman
2#define FACT_huffman
3
#include <stdint.h>
#include <string.h>

#include <set>
#include <stdexcept>
#include <string>
#include <vector>
10
11#define MAX_SYMBOLS (1<<16)
12
13// ================================================================
14
15namespace Huffman
16{
// Number of whole bytes needed to store `numbits` bits (rounded up).
static unsigned long numbytes_from_numbits(unsigned long numbits)
{
    const unsigned long whole_bytes = numbits >> 3;
    const unsigned long spare_bits  = numbits & 7;

    // A partially used byte still occupies a full byte.
    return spare_bits != 0 ? whole_bytes + 1 : whole_bytes;
}
21
// A node of the Huffman tree.
//
// A leaf carries a 16-bit symbol, an inner node carries its two children;
// both share storage through the union, and `isLeaf` tells which one is valid.
// An inner node owns its children and deletes them in its destructor.
//
// The type also serves as an ordering functor for TreeNode pointers
// (ascending by `count`), see operator().
struct TreeNode
{
    TreeNode *parent;
    union
    {
        struct
        {
            TreeNode *zero, *one;
        };
        uint16_t symbol;
    };

    size_t count;
    bool isLeaf;

    // Create a leaf for symbol `sym` which occurred `cnt` times.
    TreeNode(uint16_t sym, size_t cnt=0) : parent(0), count(cnt), isLeaf(true)
    {
        symbol = sym;
    }

    // Create an inner node. Children are only attached when both n0 and n1
    // are given: the one with the strictly larger count becomes `zero`
    // (n1 wins a tie), the other one becomes `one`. Every non-null argument
    // gets its parent pointer set to this node.
    TreeNode(TreeNode *n0=0, TreeNode *n1=0) : parent(0), count(0), isLeaf(false)
    {
        zero = nullptr;
        one  = nullptr;

        if (n0 && n1)
        {
            const bool first_is_heavier = n0->count > n1->count;

            count = n0->count + n1->count;
            zero  = first_is_heavier ? n0 : n1;
            one   = first_is_heavier ? n1 : n0;
        }

        if (n0)
            n0->parent = this;

        if (n1)
            n1->parent = this;
    }

    // Recursively release the sub-tree below an inner node.
    ~TreeNode()
    {
        if (!isLeaf)
        {
            // Deleting a null pointer is a no-op.
            delete zero;
            delete one;
        }
    }

    // Strict weak ordering on the occurrence count.
    bool operator() (const TreeNode *hn1, const TreeNode *hn2) const
    {
        return hn1->count < hn2->count;
    }
};
72
73
74 struct Encoder
75 {
76 struct Code
77 {
78 size_t bits;
79 uint8_t numbits;
80
81 Code() : numbits(0) { }
82 };
83
84 size_t count;
85 Code lut[1<<16];
86
87 void CreateEncoder(const TreeNode *n, size_t bits=0, uint8_t nbits=0)
88 {
89 if (n->isLeaf)
90 {
91 if (nbits>sizeof(size_t)*8)
92 throw std::runtime_error("Too many different symbols - this should not happen!");
93
94 lut[n->symbol].bits = bits;
95 lut[n->symbol].numbits = nbits==0 ? 1 : nbits;
96 count++;
97 return;
98 }
99
100 CreateEncoder(n->zero, bits, nbits+1);
101 CreateEncoder(n->one, bits | (1<<nbits), nbits+1);
102 }
103
104 void WriteCodeTable(std::string &out) const
105 {
106 out.append((char*)&count, sizeof(size_t));
107
108 for (uint32_t i=0; i<MAX_SYMBOLS; i++)
109 {
110 const Code &n = lut[i];
111 if (n.numbits==0)
112 continue;
113
114 // Write the 2 byte symbol.
115 out.append((char*)&i, sizeof(uint16_t));
116 if (count==1)
117 return;
118
119 // Write the 1 byte code bit length.
120 out.append((char*)&n.numbits, sizeof(uint8_t));
121
122 // Write the code bytes.
123 uint32_t numbytes = numbytes_from_numbits(n.numbits);
124 out.append((char*)&n.bits, numbytes);
125 }
126 }
127
128 void Encode(std::string &out, const uint16_t *bufin, size_t bufinlen) const
129 {
130 if (count==1)
131 return;
132
133 uint8_t curbyte = 0;
134 uint8_t curbit = 0;
135
136 for (uint32_t i=0; i<bufinlen; ++i)
137 {
138 const uint16_t &symbol = bufin[i];
139
140 const Code *code = lut+symbol;
141
142 uint8_t nbits = code->numbits;
143 const uint8_t *bits = (uint8_t*)&code->bits;
144
145 while (nbits>0)
146 {
147 // Number of bits available in the current byte
148 const uint8_t free_bits = 8 - curbit;
149
150 // Write bits to current byte
151 curbyte |= *bits<<curbit;
152
153 // If the byte has been filled, put it into the output buffer
154 // If the bits exceed the current byte step to the next byte
155 // and fill it properly
156 if (nbits>=free_bits)
157 {
158 out += curbyte;
159 curbyte = *bits>>free_bits;
160
161 bits++;
162 }
163
164 // Adapt the number of available bits, the number of consumed bits
165 // and the bit-pointer accordingly
166 const uint8_t consumed = nbits>8 ? 8 : nbits;
167 nbits -= consumed;
168 curbit += consumed;
169 curbit %= 8;
170 }
171 }
172
173 // If the buffer-byte is half-full, also add it to the output buffer
174 if (curbit>0)
175 out += curbyte;
176 }
177
178 Encoder(const uint16_t *bufin, size_t bufinlen) : count(0)
179 {
180 uint16_t counts[MAX_SYMBOLS];
181 memset(counts, 0, sizeof(uint16_t)*MAX_SYMBOLS);
182
183 // Count occurances
184 for (const uint16_t *p=bufin; p<bufin+bufinlen; p++)
185 counts[*p]++;
186
187 // Copy all occuring symbols into a sorted list
188 std::multiset<TreeNode*, TreeNode> set;
189 for (int i=0; i<MAX_SYMBOLS; i++)
190 if (counts[i])
191 set.insert(new TreeNode(i, counts[i]));
192
193 // Create the tree bottom-up
194 while (set.size()>1)
195 {
196 auto it = set.begin();
197
198 auto it1 = it++;
199 auto it2 = it;
200
201 TreeNode *nn = new TreeNode(*it1, *it2);
202
203 set.erase(it1, ++it2);
204
205 set.insert(nn);
206 }
207
208 // get the root of the tree
209 const TreeNode *root = *set.begin();
210
211 CreateEncoder(root);
212
213 // This will delete the whole tree
214 delete root;
215 }
216
217 };
218
219
220
221 struct Decoder
222 {
223 uint16_t symbol;
224 uint8_t nbits;
225 bool isLeaf;
226
227 Decoder *lut;
228
229 Decoder() : isLeaf(false), lut(NULL)
230 {
231 }
232
233 ~Decoder()
234 {
235 if (lut)
236 delete [] lut;
237 }
238
239 void Set(uint16_t sym, uint8_t n=0, size_t bits=0)
240 {
241 if (!lut)
242 lut = new Decoder[256];
243
244 if (n>8)
245 {
246 lut[bits&0xff].Set(sym, n-8, bits>>8);
247 return;
248 }
249
250 const int nn = 1<<(8-n);
251
252 for (int i=0; i<nn; i++)
253 {
254 const uint8_t key = bits | (i<<n);
255
256 lut[key].symbol = sym;
257 lut[key].isLeaf = true;
258 lut[key].nbits = n;
259 }
260 }
261
262 void Build(const TreeNode &p, uint64_t bits=0, uint8_t n=0)
263 {
264 if (p.isLeaf)
265 {
266 Set(p.symbol, n, bits);
267 return;
268 }
269
270 Build(*p.zero, bits, n+1);
271 Build(*p.one, bits | (1<<n), n+1);
272 }
273
274 Decoder(const TreeNode &p) : symbol(0), isLeaf(false), lut(NULL)
275 {
276 Build(p);
277 }
278
279 const uint8_t *Decode(const uint8_t *in_ptr, const uint8_t *in_end,
280 uint16_t *out_ptr, const uint16_t *out_end) const
281 {
282 Decoder const *p = this;
283
284 uint8_t curbit = 0;
285 while (in_ptr<in_end && out_ptr<out_end)
286 {
287 const uint16_t *two = (uint16_t*)in_ptr;
288
289 const uint8_t curbyte = (*two >> curbit);
290
291 if (!p->lut)
292 throw std::runtime_error("Unknown bitcode in stream!");
293 p = p->lut + curbyte;
294 if (!p->isLeaf)
295 {
296 in_ptr++;
297 continue;
298 }
299
300 *out_ptr++ = p->symbol;
301 curbit += p->nbits;
302
303 p = this;
304
305 if (curbit>=8)
306 {
307 curbit %= 8;
308 in_ptr++;
309 }
310
311 }
312
313 return curbit ? in_ptr+1 : in_ptr;
314 }
315
316 Decoder(const uint8_t* bufin, unsigned int &pindex) : isLeaf(false), lut(NULL)
317 {
318 // FIXME: Sanity check for size missing....
319
320 // Read the number of entries.
321 size_t count=0;
322 memcpy(&count, bufin + pindex, sizeof(count));
323 pindex += sizeof(count);
324
325 // Read the entries.
326 for (size_t i=0; i<count; i++)
327 {
328 uint16_t sym;
329 memcpy(&sym, bufin + pindex, sizeof(uint16_t));
330 pindex += sizeof(uint16_t);
331
332 if (count==1)
333 {
334 Set(sym);
335 break;
336 }
337
338 uint8_t numbits;
339 memcpy(&numbits, bufin + pindex, sizeof(uint8_t));
340 pindex += sizeof(uint8_t);
341
342 const uint8_t numbytes = numbytes_from_numbits(numbits);
343 if (numbytes>sizeof(size_t))
344 throw std::runtime_error("Number of bytes for a single symbol exceeds maximum.");
345
346 size_t bits=0;
347 memcpy(&bits, bufin+pindex, numbytes);
348 pindex += numbytes;
349
350 Set(sym, numbits, bits);
351 }
352 }
353 };
354
355 void Encode(std::string &bufout, const uint16_t *bufin, size_t bufinlen)
356 {
357 bufout.append((char*)&bufinlen, sizeof(size_t));
358
359 const Encoder encoder(bufin, bufinlen);
360 encoder.WriteCodeTable(bufout);
361 encoder.Encode(bufout, bufin, bufinlen);
362 }
363
364 int Decode(const uint8_t *bufin,
365 size_t bufinlen,
366 std::vector<uint16_t> &pbufout)
367 {
368 unsigned int i = 0;
369
370 // Read the number of data bytes this encoding represents.
371 size_t data_count = 0;
372 memcpy(&data_count, bufin, sizeof(size_t));
373 i += sizeof(size_t);
374
375
376 pbufout.resize(data_count);
377
378 const Decoder decoder(bufin, i);
379
380 const uint8_t *in_ptr =
381 decoder.Decode(bufin+i, bufin+bufinlen,
382 pbufout.data(), pbufout.data()+data_count);
383
384 return in_ptr-bufin;
385 }
386};
387
388#endif
Note: See TracBrowser for help on using the repository browser.