source: trunk/Mars/mcore/huffman.h@ 18123

Last change on this file since 18123 was 17142, checked in by tbretz, 11 years ago
The two functions need an inline tag; otherwise they can end up in two different object files, which then cannot be linked together.
File size: 10.9 KB
Line 
1#ifndef FACT_huffman
2#define FACT_huffman
3
#include <string.h>
#include <stdint.h>

#include <memory>
#include <set>
#include <stdexcept>
#include <string>
#include <vector>
10
// Size of the symbol alphabet: symbols are uint16_t, so every 16-bit value
// may occur (used to size the histogram and to iterate the code table).
#define MAX_SYMBOLS (1<<16)
12
13// ================================================================
14
15namespace Huffman
16{
// Number of whole bytes needed to hold 'numbits' bits (rounded up).
// 'inline' rather than 'static': a header-defined function then has one
// definition program-wide instead of a private copy (and an unused-function
// warning) in every translation unit that includes this header.
inline unsigned long numbytes_from_numbits(unsigned long numbits)
{
    return numbits / 8 + (numbits % 8 ? 1 : 0);
}
21
// Node of a Huffman code tree.
//
// A leaf (isLeaf==true) carries one symbol and its occurrence count. An
// internal node owns its two children 'zero' and 'one' and deletes them in
// its destructor, so deleting the root frees the whole tree. operator()
// orders nodes by count, which lets the struct itself serve as the comparator
// type of a std::multiset<TreeNode*, TreeNode>.
struct TreeNode
{
    TreeNode *parent;

    // Children of an internal node (nullptr for leaves). Plain members
    // instead of an anonymous struct inside a union: anonymous structs are a
    // compiler extension, not standard C++, and the union left the child
    // pointers of a leaf aliasing the symbol bytes.
    TreeNode *zero;
    TreeNode *one;

    // Symbol of a leaf (0 for internal nodes)
    uint16_t symbol;

    size_t count;
    bool isLeaf;

    // Leaf for symbol 'sym' which occurred 'cnt' times
    TreeNode(uint16_t sym, size_t cnt=0)
        : parent(nullptr), zero(nullptr), one(nullptr),
          symbol(sym), count(cnt), isLeaf(true)
    {
    }

    // Internal node joining n0 and n1. Its count is the sum of both counts;
    // 'zero' becomes the child with the strictly larger count ('one' the
    // other, so on a tie n1 is 'zero'). Unless both children are given, the
    // node has no children and count 0. Any given child gets this node as
    // its parent.
    TreeNode(TreeNode *n0=nullptr, TreeNode *n1=nullptr)
        : parent(nullptr), zero(nullptr), one(nullptr),
          symbol(0), count(0), isLeaf(false)
    {
        if (n0 && n1)
        {
            count = n0->count + n1->count;

            const bool n0larger = n0->count > n1->count;
            zero = n0larger ? n0 : n1;
            one  = n0larger ? n1 : n0;
        }

        if (n0)
            n0->parent = this;

        if (n1)
            n1->parent = this;
    }

    ~TreeNode()
    {
        // Leaves own nothing
        if (isLeaf)
            return;

        // delete on a null pointer is a no-op
        delete zero;
        delete one;
    }

    // Comparator: orders nodes by ascending count
    bool operator() (const TreeNode *hn1, const TreeNode *hn2) const
    {
        return hn1->count < hn2->count;
    }
};
72
73
74 struct Encoder
75 {
76 struct Code
77 {
78 size_t bits;
79 uint8_t numbits;
80
81 Code() : numbits(0) { }
82 };
83
84 size_t count;
85 Code lut[1<<16];
86
87 bool CreateEncoder(const TreeNode *n, size_t bits=0, uint8_t nbits=0)
88 {
89 if (n->isLeaf)
90 {
91#ifdef __EXCEPTIONS
92 if (nbits>sizeof(size_t)*8)
93 throw std::runtime_error("Too many different symbols - this should not happen!");
94#else
95 if (nbits>sizeof(size_t)*8)
96 {
97 count = 0;
98 return false;
99 }
100#endif
101 lut[n->symbol].bits = bits;
102 lut[n->symbol].numbits = nbits==0 ? 1 : nbits;
103 count++;
104 return true;
105 }
106
107 return
108 CreateEncoder(n->zero, bits, nbits+1) &&
109 CreateEncoder(n->one, bits | (1<<nbits), nbits+1);
110
111 }
112
113 void WriteCodeTable(std::string &out) const
114 {
115 out.append((char*)&count, sizeof(size_t));
116
117 for (uint32_t i=0; i<MAX_SYMBOLS; i++)
118 {
119 const Code &n = lut[i];
120 if (n.numbits==0)
121 continue;
122
123 // Write the 2 byte symbol.
124 out.append((char*)&i, sizeof(uint16_t));
125 if (count==1)
126 return;
127
128 // Write the 1 byte code bit length.
129 out.append((char*)&n.numbits, sizeof(uint8_t));
130
131 // Write the code bytes.
132 uint32_t numbytes = numbytes_from_numbits(n.numbits);
133 out.append((char*)&n.bits, numbytes);
134 }
135 }
136
137 void Encode(std::string &out, const uint16_t *bufin, size_t bufinlen) const
138 {
139 if (count==1)
140 return;
141
142 uint8_t curbyte = 0;
143 uint8_t curbit = 0;
144
145 for (uint32_t i=0; i<bufinlen; ++i)
146 {
147 const uint16_t &symbol = bufin[i];
148
149 const Code *code = lut+symbol;
150
151 uint8_t nbits = code->numbits;
152 const uint8_t *bits = (uint8_t*)&code->bits;
153
154 while (nbits>0)
155 {
156 // Number of bits available in the current byte
157 const uint8_t free_bits = 8 - curbit;
158
159 // Write bits to current byte
160 curbyte |= *bits<<curbit;
161
162 // If the byte has been filled, put it into the output buffer
163 // If the bits exceed the current byte step to the next byte
164 // and fill it properly
165 if (nbits>=free_bits)
166 {
167 out += curbyte;
168 curbyte = *bits>>free_bits;
169
170 bits++;
171 }
172
173 // Adapt the number of available bits, the number of consumed bits
174 // and the bit-pointer accordingly
175 const uint8_t consumed = nbits>8 ? 8 : nbits;
176 nbits -= consumed;
177 curbit += consumed;
178 curbit %= 8;
179 }
180 }
181
182 // If the buffer-byte is half-full, also add it to the output buffer
183 if (curbit>0)
184 out += curbyte;
185 }
186
187 Encoder(const uint16_t *bufin, size_t bufinlen) : count(0)
188 {
189 uint64_t counts[MAX_SYMBOLS];
190 memset(counts, 0, sizeof(uint64_t)*MAX_SYMBOLS);
191
192 // Count occurances
193 for (const uint16_t *p=bufin; p<bufin+bufinlen; p++)
194 counts[*p]++;
195
196 // Copy all occuring symbols into a sorted list
197 std::multiset<TreeNode*, TreeNode> set;
198 for (int i=0; i<MAX_SYMBOLS; i++)
199 if (counts[i])
200 set.insert(new TreeNode(i, counts[i]));
201
202 // Create the tree bottom-up
203 while (set.size()>1)
204 {
205 auto it = set.begin();
206
207 auto it1 = it++;
208 auto it2 = it;
209
210 TreeNode *nn = new TreeNode(*it1, *it2);
211
212 set.erase(it1, ++it2);
213
214 set.insert(nn);
215 }
216
217 // get the root of the tree
218 const TreeNode *root = *set.begin();
219
220 CreateEncoder(root);
221
222 // This will delete the whole tree
223 delete root;
224 }
225
226 };
227
228
229
230 struct Decoder
231 {
232 uint16_t symbol;
233 uint8_t nbits;
234 bool isLeaf;
235
236 Decoder *lut;
237
238 Decoder() : isLeaf(false), lut(NULL)
239 {
240 }
241
242 ~Decoder()
243 {
244 if (lut)
245 delete [] lut;
246 }
247
248 void Set(uint16_t sym, uint8_t n=0, size_t bits=0)
249 {
250 if (!lut)
251 lut = new Decoder[256];
252
253 if (n>8)
254 {
255 lut[bits&0xff].Set(sym, n-8, bits>>8);
256 return;
257 }
258
259 const int nn = 1<<(8-n);
260
261 for (int i=0; i<nn; i++)
262 {
263 const uint8_t key = bits | (i<<n);
264
265 lut[key].symbol = sym;
266 lut[key].isLeaf = true;
267 lut[key].nbits = n;
268 }
269 }
270
271 void Build(const TreeNode &p, uint64_t bits=0, uint8_t n=0)
272 {
273 if (p.isLeaf)
274 {
275 Set(p.symbol, n, bits);
276 return;
277 }
278
279 Build(*p.zero, bits, n+1);
280 Build(*p.one, bits | (1<<n), n+1);
281 }
282
283 Decoder(const TreeNode &p) : symbol(0), isLeaf(false), lut(NULL)
284 {
285 Build(p);
286 }
287
288 const uint8_t *Decode(const uint8_t *in_ptr, const uint8_t *in_end,
289 uint16_t *out_ptr, const uint16_t *out_end) const
290 {
291 Decoder const *p = this;
292
293 if (in_ptr==in_end)
294 {
295 while (out_ptr < out_end)
296 *out_ptr++ = p->lut->symbol;
297 return in_ptr;
298 }
299
300 uint8_t curbit = 0;
301 while (in_ptr<in_end && out_ptr<out_end)
302 {
303 const uint16_t *two = (uint16_t*)in_ptr;
304
305 const uint8_t curbyte = (*two >> curbit);
306
307#ifdef __EXCEPTIONS
308 if (!p->lut)
309 throw std::runtime_error("Unknown bitcode in stream!");
310#else
311 if (!p->lut)
312 return NULL;
313#endif
314
315 p = p->lut + curbyte;
316 if (!p->isLeaf)
317 {
318 in_ptr++;
319 continue;
320 }
321
322 *out_ptr++ = p->symbol;
323 curbit += p->nbits;
324
325 p = this;
326
327 if (curbit>=8)
328 {
329 curbit %= 8;
330 in_ptr++;
331 }
332
333 }
334
335 return curbit ? in_ptr+1 : in_ptr;
336 }
337
338 Decoder(const uint8_t* bufin, int64_t &pindex) : isLeaf(false), lut(NULL)
339 {
340 // FIXME: Sanity check for size missing....
341
342 // Read the number of entries.
343 size_t count=0;
344 memcpy(&count, bufin + pindex, sizeof(count));
345 pindex += sizeof(count);
346
347 // Read the entries.
348 for (size_t i=0; i<count; i++)
349 {
350 uint16_t sym;
351 memcpy(&sym, bufin + pindex, sizeof(uint16_t));
352 pindex += sizeof(uint16_t);
353
354 if (count==1)
355 {
356 Set(sym);
357 break;
358 }
359
360 uint8_t numbits;
361 memcpy(&numbits, bufin + pindex, sizeof(uint8_t));
362 pindex += sizeof(uint8_t);
363
364 const uint8_t numbytes = numbytes_from_numbits(numbits);
365
366#ifdef __EXCEPTIONS
367 if (numbytes>sizeof(size_t))
368 throw std::runtime_error("Number of bytes for a single symbol exceeds maximum.");
369#else
370 if (numbytes>sizeof(size_t))
371 {
372 pindex = -1;
373 return;
374 }
375#endif
376 size_t bits=0;
377 memcpy(&bits, bufin+pindex, numbytes);
378 pindex += numbytes;
379
380 Set(sym, numbits, bits);
381 }
382 }
383 };
384
385 inline bool Encode(std::string &bufout, const uint16_t *bufin, size_t bufinlen)
386 {
387 const Encoder encoder(bufin, bufinlen);
388
389#ifndef __EXCEPTIONS
390 if (encoder.count==0)
391 return false;
392#endif
393
394 bufout.append((char*)&bufinlen, sizeof(size_t));
395 encoder.WriteCodeTable(bufout);
396 encoder.Encode(bufout, bufin, bufinlen);
397
398 return true;
399 }
400
401 inline int64_t Decode(const uint8_t *bufin,
402 size_t bufinlen,
403 std::vector<uint16_t> &pbufout)
404 {
405 int64_t i = 0;
406
407 // Read the number of data bytes this encoding represents.
408 size_t data_count = 0;
409 memcpy(&data_count, bufin, sizeof(size_t));
410 i += sizeof(size_t);
411
412
413 pbufout.resize(data_count);
414
415 const Decoder decoder(bufin, i);
416
417#ifndef __EXCEPTIONS
418 if (i==-1)
419 return -1;
420#endif
421
422 const uint8_t *in_ptr =
423 decoder.Decode(bufin+i, bufin+bufinlen,
424 pbufout.data(), pbufout.data()+data_count);
425
426#ifndef __EXCEPTIONS
427 if (!in_ptr)
428 return -1;
429#endif
430
431 return in_ptr-bufin;
432 }
433};
434
435#endif
Note: See TracBrowser for help on using the repository browser.