source: branches/Mars_MC/mcore/huffman.h@ 17893

Last change on this file since 17893 was 16442, checked in by tbretz, 12 years ago
If exceptions are not enabled, don't throw...
File size: 10.9 KB
Line 
1#ifndef FACT_huffman
2#define FACT_huffman
3
#include <stdint.h>
#include <string.h>

#include <memory>
#include <set>
#include <stdexcept>
#include <string>
#include <vector>
10
11#define MAX_SYMBOLS (1<<16)
12
13// ================================================================
14
// Huffman coding of streams of 16-bit symbols.
//
// Layout of an encoded stream as produced by Huffman::Encode():
//   size_t   number of encoded symbols
//   size_t   number of entries in the code table
//   entries  2 byte symbol [, 1 byte code length in bits, code bytes]
//            (a table with a single entry stores the symbol only)
//   payload  the codes of all symbols, packed LSB first
//
// All multi-byte values are written in the byte order of the host and code
// bits are accessed through the in-memory bytes of a size_t, so the code
// relies on a little-endian host.
namespace Huffman
{
    // Number of distinct symbols: every value a uint16_t can take. Same value
    // as the file-level MAX_SYMBOLS macro (1<<16).
    static const size_t kNumSymbols = size_t(1) << (8*sizeof(uint16_t));

    // Number of bytes required to store numbits bits (rounded up).
    static inline unsigned long numbytes_from_numbits(unsigned long numbits)
    {
        return numbits / 8 + (numbits % 8 ? 1 : 0);
    }

    // Node of the binary code tree. A leaf carries a symbol, an inner node
    // owns its two children (they are deleted together with the node).
    //
    // The struct doubles as the ordering functor for
    // std::multiset<TreeNode*, TreeNode> (ascending count).
    //
    // NOTE: Copying is shallow. Never copy an inner node which owns children,
    // otherwise the children are deleted twice.
    struct TreeNode
    {
        TreeNode *parent;
        union
        {
            struct
            {
                TreeNode *zero, *one;
            };
            uint16_t symbol;
        };

        size_t count;
        bool   isLeaf;

        // Create a leaf for symbol sym which occurs cnt times.
        TreeNode(uint16_t sym, size_t cnt=0) : parent(nullptr), count(cnt), isLeaf(true)
        {
            symbol = sym;
        }

        // Create an inner node. If both children are given, the one with the
        // larger count becomes the zero-branch (on equal counts n1) and the
        // counts are summed up. If one is missing, no child is attached and
        // the count is zero, but a given node still gets this as its parent.
        TreeNode(TreeNode *n0=nullptr, TreeNode *n1=nullptr) : parent(nullptr), count(0), isLeaf(false)
        {
            zero = nullptr;
            one  = nullptr;

            if (n0 && n1)
            {
                const bool n0_is_larger = n0->count > n1->count;

                count = n0->count + n1->count;
                zero  = n0_is_larger ? n0 : n1;
                one   = n0_is_larger ? n1 : n0;
            }

            if (n0)
                n0->parent = this;

            if (n1)
                n1->parent = this;
        }

        // Deletes the whole sub-tree below an inner node.
        ~TreeNode()
        {
            // For a leaf the union holds the symbol, not pointers
            if (isLeaf)
                return;

            delete zero;
            delete one;
        }

        // Strict weak ordering by count (used as comparator of the multiset)
        bool operator() (const TreeNode *hn1, const TreeNode *hn2) const
        {
            return hn1->count < hn2->count;
        }
    };


    // Builds the code table for a buffer of symbols and encodes with it.
    //
    // The object contains a table with one entry per possible symbol and is
    // therefore large; prefer to allocate it on the heap.
    struct Encoder
    {
        struct Code
        {
            size_t  bits;     // The code, first bit of the code in bit 0
            uint8_t numbits;  // Length of the code in bits, 0: symbol unused

            Code() : bits(0), numbits(0) { }
        };

        size_t count;             // Number of symbols with a code assigned
        Code   lut[kNumSymbols];  // Code for each symbol

        // Signal that a code does not fit into a size_t: throws if
        // exceptions are enabled, otherwise resets count and returns false.
        bool CodeTooLong()
        {
#ifdef __EXCEPTIONS
            throw std::runtime_error("Too many different symbols - this should not happen!");
#else
            count = 0;
            return false;
#endif
        }

        // Walk the tree below n recursively and store the code of each leaf
        // in the look-up table. bits/nbits is the code of the path to n.
        // Returns false (or throws, see CodeTooLong) if a code is too long.
        bool CreateEncoder(const TreeNode *n, size_t bits=0, uint8_t nbits=0)
        {
            const size_t max_bits = sizeof(size_t)*8;

            if (n->isLeaf)
            {
                if (nbits>max_bits)
                    return CodeTooLong();

                Code &code = lut[n->symbol];

                code.bits    = bits;
                // A tree consisting of a single leaf still gets a 1-bit code
                code.numbits = nbits==0 ? 1 : nbits;

                count++;
                return true;
            }

            // Every leaf below this node would exceed the limit. Checking
            // here also keeps the shift below within the width of size_t.
            if (nbits>=max_bits)
                return CodeTooLong();

            // The shift must be done in size_t: an int (1<<nbits) yields
            // wrong codes as soon as a code gets longer than 31 bits
            return
                CreateEncoder(n->zero, bits,                       nbits+1) &&
                CreateEncoder(n->one,  bits | (size_t(1)<<nbits),  nbits+1);
        }

        // Append the code table to out: the number of entries followed by
        // one record per used symbol (see the layout at the namespace).
        void WriteCodeTable(std::string &out) const
        {
            out.append(reinterpret_cast<const char*>(&count), sizeof(size_t));

            for (size_t i=0; i<kNumSymbols; i++)
            {
                const Code &n = lut[i];
                if (n.numbits==0)
                    continue;

                // Write the 2 byte symbol. (Narrow first, so that really the
                // symbol is written and not just the first two bytes of a
                // wider integer.)
                const uint16_t sym = static_cast<uint16_t>(i);
                out.append(reinterpret_cast<const char*>(&sym), sizeof(uint16_t));

                // With a single symbol no code is needed
                if (count==1)
                    return;

                // Write the 1 byte code bit length.
                out.append(reinterpret_cast<const char*>(&n.numbits), sizeof(uint8_t));

                // Write the code bytes.
                const unsigned long numbytes = numbytes_from_numbits(n.numbits);
                out.append(reinterpret_cast<const char*>(&n.bits), numbytes);
            }
        }

        // Append the packed codes of the bufinlen symbols in bufin to out.
        // Nothing is written if the table contains only a single symbol.
        void Encode(std::string &out, const uint16_t *bufin, size_t bufinlen) const
        {
            if (count==1)
                return;

            uint8_t curbyte = 0;  // Byte currently being assembled
            uint8_t curbit  = 0;  // Number of bits already used in curbyte

            for (size_t i=0; i<bufinlen; ++i)
            {
                const Code &code = lut[bufin[i]];

                uint8_t nbits = code.numbits;

                // The code is processed byte-wise, starting at its lowest byte
                const uint8_t *bits = reinterpret_cast<const uint8_t*>(&code.bits);

                while (nbits>0)
                {
                    // Number of bits available in the current byte
                    const uint8_t free_bits = 8 - curbit;

                    // Write bits to current byte
                    curbyte |= *bits<<curbit;

                    // If the byte has been filled, put it into the output
                    // buffer. The bits which did not fit start the next byte.
                    if (nbits>=free_bits)
                    {
                        out += curbyte;
                        curbyte = *bits>>free_bits;

                        bits++;
                    }

                    // At most one byte of the code is consumed per iteration
                    const uint8_t consumed = nbits>8 ? 8 : nbits;
                    nbits  -= consumed;
                    curbit += consumed;
                    curbit %= 8;
                }
            }

            // If the buffer-byte is half-full, also add it to the output buffer
            if (curbit>0)
                out += curbyte;
        }

        // Count the symbols in bufin, build the code tree and derive the
        // code table from it. For an empty input count stays zero and no
        // code is assigned.
        Encoder(const uint16_t *bufin, size_t bufinlen) : count(0)
        {
            // Count occurrences (on the heap, the table is 512kB)
            std::vector<uint64_t> counts(kNumSymbols, uint64_t(0));
            for (size_t i=0; i<bufinlen; i++)
                counts[bufin[i]]++;

            // Copy all occurring symbols into a list sorted by count
            std::multiset<TreeNode*, TreeNode> nodes;
            for (size_t i=0; i<kNumSymbols; i++)
                if (counts[i])
                    nodes.insert(new TreeNode(static_cast<uint16_t>(i), static_cast<size_t>(counts[i])));

            // Nothing to encode: there is no tree and no root to access
            if (nodes.empty())
                return;

            // Create the tree bottom-up: always join the two rarest nodes
            while (nodes.size()>1)
            {
                auto it1 = nodes.begin();
                auto it2 = it1;
                ++it2;

                TreeNode *nn = new TreeNode(*it1, *it2);

                nodes.erase(it1, ++it2);
                nodes.insert(nn);
            }

            // The remaining node is the root. Deleting it deletes the whole
            // tree, also if CreateEncoder throws.
            const std::unique_ptr<TreeNode> root(*nodes.begin());

            CreateEncoder(root.get());
        }
    };


    // Decoder implemented as a tree of look-up tables with 256 entries each.
    // Each level is indexed with the next eight bits of the stream; an entry
    // is either a leaf (symbol found) or refers to the next level.
    struct Decoder
    {
        uint16_t symbol;  // Decoded symbol (leaf only)
        uint8_t  nbits;   // Number of bits the code uses on this level (leaf only)
        bool     isLeaf;

        Decoder *lut;     // Next level (array of 256 entries) or NULL

        Decoder() : symbol(0), nbits(0), isLeaf(false), lut(nullptr)
        {
        }

        ~Decoder()
        {
            delete [] lut;
        }

        // The object owns lut, a shallow copy would delete it twice
        Decoder(const Decoder &) = delete;
        Decoder &operator=(const Decoder &) = delete;

        // Register the code bits of length n for symbol sym. Codes shorter
        // than eight bits fill all entries which start with the code.
        void Set(uint16_t sym, uint8_t n=0, size_t bits=0)
        {
            if (!lut)
                lut = new Decoder[256];

            // The first eight bits select the entry, the rest the next level
            if (n>8)
            {
                lut[bits&0xff].Set(sym, n-8, bits>>8);
                return;
            }

            // Number of combinations of the bits not determined by the code
            const unsigned int num = 1u<<(8-n);

            for (unsigned int i=0; i<num; i++)
            {
                const uint8_t key = static_cast<uint8_t>(bits | (i<<n));

                lut[key].symbol = sym;
                lut[key].isLeaf = true;
                lut[key].nbits  = n;
            }
        }

        // Register the codes of all leaves below p. bits/n is the code of
        // the path to p.
        void Build(const TreeNode &p, uint64_t bits=0, uint8_t n=0)
        {
            if (p.isLeaf)
            {
                Set(p.symbol, n, bits);
                return;
            }

            // The shift must be done in 64 bit (an int 1<<n breaks codes with
            // more than 31 bits). Codes of more than 64 bits cannot be
            // represented at all.
            const uint64_t bit = n<64 ? uint64_t(1)<<n : 0;

            Build(*p.zero, bits,       n+1);
            Build(*p.one,  bits | bit, n+1);
        }

        // Create a decoder for the codes defined by the tree with root p.
        // (Delegation ensures that the tables are freed if Build throws.)
        Decoder(const TreeNode &p) : Decoder()
        {
            Build(p);
        }

        // Decode the bytes in [in_ptr, in_end) into [out_ptr, out_end).
        //
        // Decoding stops when either the output is full or the input is
        // exhausted. Returns a pointer behind the last input byte which was
        // (at least partially) consumed. If the stream contains a code which
        // is not in the table, an exception is thrown or, if exceptions are
        // disabled, NULL is returned.
        const uint8_t *Decode(const uint8_t *in_ptr, const uint8_t *in_end,
                              uint16_t *out_ptr, const uint16_t *out_end) const
        {
            Decoder const *p = this;

            // No payload: this is the case if only a single symbol is in
            // the table (or if there is nothing to decode at all).
            if (in_ptr==in_end)
            {
                if (out_ptr>=out_end)
                    return in_ptr;

                if (!lut)
                {
#ifdef __EXCEPTIONS
                    throw std::runtime_error("Empty code table but symbols expected!");
#else
                    return nullptr;
#endif
                }

                while (out_ptr < out_end)
                    *out_ptr++ = lut->symbol;
                return in_ptr;
            }

            uint8_t curbit = 0;  // Number of bits already consumed from *in_ptr
            while (in_ptr<in_end && out_ptr<out_end)
            {
                // Look at the current and the next byte as a 16 bit word.
                // Assembling it byte-wise avoids an unaligned access and a
                // read behind the end of the buffer at the last byte.
                uint16_t two = in_ptr[0];
                if (in_ptr+1<in_end)
                    two |= static_cast<uint16_t>(in_ptr[1])<<8;

                // The next eight unprocessed bits of the stream
                const uint8_t curbyte = static_cast<uint8_t>(two >> curbit);

                if (!p->lut)
                {
#ifdef __EXCEPTIONS
                    throw std::runtime_error("Unknown bitcode in stream!");
#else
                    return nullptr;
#endif
                }

                p = p->lut + curbyte;

                // The code is longer than the bits checked so far: eight
                // more bits are consumed, continue with the next level
                if (!p->isLeaf)
                {
                    in_ptr++;
                    continue;
                }

                *out_ptr++ = p->symbol;
                curbit += p->nbits;

                // Start with the next symbol at the first level
                p = this;

                if (curbit>=8)
                {
                    curbit %= 8;
                    in_ptr++;
                }
            }

            // A partially consumed byte counts as consumed (but never return
            // a pointer behind the end of the input)
            return curbit && in_ptr<in_end ? in_ptr+1 : in_ptr;
        }

        // Read the code table stored in bufin at position pindex and advance
        // pindex behind it. bufinlen is the total size of bufin; if it is
        // given, all read accesses are checked against it.
        //
        // On a corrupt table an exception is thrown or, if exceptions are
        // disabled, pindex is set to -1.
        // (Delegation ensures that the tables are freed if this throws.)
        Decoder(const uint8_t* bufin, int64_t &pindex, size_t bufinlen=static_cast<size_t>(-1)) : Decoder()
        {
            const char *error = ReadTable(bufin, pindex, bufinlen);
            if (!error)
                return;

#ifdef __EXCEPTIONS
            throw std::runtime_error(error);
#else
            pindex = -1;
#endif
        }

    private:
        // True if need bytes starting at pos are within a buffer of len bytes
        static bool Fits(size_t pos, size_t need, size_t len)
        {
            return pos<=len && need<=len-pos;
        }

        // Does the work for the constructor from a buffer. Returns NULL on
        // success (pindex is updated) or a description of the problem.
        const char *ReadTable(const uint8_t *bufin, int64_t &pindex, size_t bufinlen)
        {
            static const char *const truncated = "Code table exceeds input buffer.";

            if (pindex<0)
                return "Invalid position of code table.";

            size_t pos = static_cast<size_t>(pindex);

            // Read the number of entries.
            size_t count = 0;
            if (!Fits(pos, sizeof(count), bufinlen))
                return truncated;

            memcpy(&count, bufin+pos, sizeof(count));
            pos += sizeof(count);

            // There cannot be more entries than different symbols
            if (count>kNumSymbols)
                return "Number of symbols in code table exceeds maximum.";

            // Read the entries.
            for (size_t i=0; i<count; i++)
            {
                uint16_t sym;
                if (!Fits(pos, sizeof(sym), bufinlen))
                    return truncated;

                memcpy(&sym, bufin+pos, sizeof(sym));
                pos += sizeof(sym);

                // A single symbol is stored without code
                if (count==1)
                {
                    Set(sym);
                    break;
                }

                if (!Fits(pos, sizeof(uint8_t), bufinlen))
                    return truncated;

                const uint8_t numbits = bufin[pos];
                pos += sizeof(uint8_t);

                const size_t numbytes = numbytes_from_numbits(numbits);
                if (numbytes>sizeof(size_t))
                    return "Number of bytes for a single symbol exceeds maximum.";

                if (!Fits(pos, numbytes, bufinlen))
                    return truncated;

                size_t bits = 0;
                memcpy(&bits, bufin+pos, numbytes);
                pos += numbytes;

                Set(sym, numbits, bits);
            }

            pindex = static_cast<int64_t>(pos);
            return nullptr;
        }
    };

    // Encode the bufinlen symbols in bufin and append the result (header,
    // code table and payload) to bufout. An empty input yields a valid
    // stream without symbols.
    //
    // Returns false if no code table could be created (only if exceptions
    // are disabled, otherwise an exception is thrown).
    inline bool Encode(std::string &bufout, const uint16_t *bufin, size_t bufinlen)
    {
        // The encoder is about 1MB, too large for the stack of some threads
        const std::unique_ptr<const Encoder> encoder(new Encoder(bufin, bufinlen));

#ifndef __EXCEPTIONS
        // count==0 signals an error unless there was nothing to encode
        if (encoder->count==0 && bufinlen>0)
            return false;
#endif

        bufout.append(reinterpret_cast<const char*>(&bufinlen), sizeof(size_t));
        encoder->WriteCodeTable(bufout);
        encoder->Encode(bufout, bufin, bufinlen);

        return true;
    }

    // Decode a stream created by Encode(). pbufout is resized to the number
    // of symbols stored in the header of the stream.
    //
    // Returns the number of bytes consumed from bufin. On a corrupt stream
    // an exception is thrown or, if exceptions are disabled, -1 is returned.
    inline int64_t Decode(const uint8_t *bufin,
                          size_t         bufinlen,
                          std::vector<uint16_t> &pbufout)
    {
        if (bufinlen<sizeof(size_t))
        {
#ifdef __EXCEPTIONS
            throw std::runtime_error("Input buffer too small for header.");
#else
            return -1;
#endif
        }

        // Read the number of symbols this encoding represents.
        size_t data_count = 0;
        memcpy(&data_count, bufin, sizeof(size_t));

        int64_t i = sizeof(size_t);

        pbufout.resize(data_count);

        const Decoder decoder(bufin, i, bufinlen);

#ifndef __EXCEPTIONS
        if (i==-1)
            return -1;
#endif

        const uint8_t *in_ptr =
            decoder.Decode(bufin+i, bufin+bufinlen,
                           pbufout.data(), pbufout.data()+data_count);

#ifndef __EXCEPTIONS
        if (!in_ptr)
            return -1;
#endif

        return in_ptr-bufin;
    }
}
434
435#endif
Note: See TracBrowser for help on using the repository browser.