#include "buffer.h"
#include <assert.h>
#include <stdlib.h>
// read
// Read `bits` bits from the buffer, most significant bit first, starting at
// the current byte/bit position, and store them right-aligned in *word.
//
// Returns 1 on success. A request for zero bits succeeds immediately and
// leaves *word untouched. Returns 0 (buffer unchanged) when `bits` exceeds 32,
// `word` is NULL, or fewer than `bits` unread bits remain.
//
// NOTE(review): the shift arithmetic assumes Word and Code are 32-bit
// unsigned types -- confirm against buffer.h.
int buf_read_bits(Buffer *b, Word *word, Bits bits) {
    assert(b);
    if (bits == 0) {
        return 1;
    }

    size_t avail = b->end - b->cur;
    // The exact bit count is only computed when fewer than 16 bytes remain.
    if (avail < 16 && avail * 8 - b->bits < bits) {
        return 0;
    }
    if (bits > 32 || !word) {
        return 0;
    }

    Byte *src = b->cur;
    Bits offset = b->bits;
    Word acc;

    if (offset == 0) {
        // Byte-aligned: the 32-bit window is exactly four bytes.
        acc = ((Code)src[0]) << 24;
        if (avail >= 4) {
            acc |= ((Code)src[1]) << 16;
            acc |= ((Code)src[2]) << 8;
            acc |= ((Code)src[3]);
        } else {
            // Near the end: only touch bytes that exist.
            if (avail >= 2) {
                acc |= ((Code)src[1]) << 16;
            }
            if (avail >= 3) {
                acc |= ((Code)src[2]) << 8;
            }
        }
    } else {
        // Unaligned: the window straddles five bytes. Shifting the first
        // byte by offset+24 pushes its already-consumed bits off the top.
        assert(offset < 8);
        acc = ((Code)src[0]) << (offset + 24);
        if (avail >= 5) {
            acc |= ((Code)src[1]) << (offset + 16);
            acc |= ((Code)src[2]) << (offset + 8);
            acc |= ((Code)src[3]) << offset;
            acc |= ((Code)src[4]) >> (8 - offset);
        } else {
            // Near the end: only touch bytes that exist.
            if (avail >= 2) {
                acc |= ((Code)src[1]) << (offset + 16);
            }
            if (avail >= 3) {
                acc |= ((Code)src[2]) << (offset + 8);
            }
            if (avail >= 4) {
                acc |= ((Code)src[3]) << offset;
            }
        }
    }

    // Advance by whole bytes and keep the leftover bit offset (0..7).
    offset += bits;
    b->cur += offset >> 3;
    b->bits = offset & 0x7;

    // The requested bits sit at the top of the window; right-align them.
    *word = acc >> (32 - bits);
    return 1;
}
// Skip the unread remainder of a partially consumed byte so that the read
// position is byte-aligned again. Does nothing when already aligned.
void buf_read_pad(Buffer *b) {
    if (b->bits == 0) {
        return;
    }
    // A non-zero bit offset implies the cursor is on a real, partly read byte.
    assert(b->cur < b->end);
    assert(b->bits < 8);
    b->cur += 1;
    b->bits = 0;
}
// Read one whole byte into *byte. If the read position is in the middle of a
// byte, the rest of that byte is skipped first (this happens even when the
// call then fails).
//
// Returns 1 on success, 0 when no byte is left or `byte` is NULL.
int buf_read_byte(Buffer *b, Byte *byte) {
    if (b->bits != 0) {
        buf_read_pad(b);
    }
    if (!byte || b->cur >= b->end) {
        return 0;
    }
    *byte = b->cur[0];
    b->cur += 1;
    return 1;
}
// write
// Append the low `bits` bits of `word` to the buffer, most significant bit
// first, starting at the current byte/bit position.
//
// Returns 1 on success (a request for zero bits succeeds immediately) and 0,
// leaving the buffer unchanged, when `bits` exceeds 32 or the remaining space
// cannot hold `bits` more bits.
//
// The write always stores a full window starting at the cursor (4 bytes when
// aligned, 5 when unaligned), clipped to the end of the buffer; bit positions
// in that window beyond the written bits are set to zero.
//
// NOTE(review): the shift arithmetic assumes Word is a 32-bit unsigned type
// -- confirm against buffer.h.
int buf_write_bits(Buffer *b, Word word, Bits bits) {
    assert(b);
    if (!bits) {
        return 1;
    }
    if (bits > 32 || b->cur >= b->end) {
        return 0;
    }

    size_t remain = b->end - b->cur;
    Bits bb = b->bits;
    assert(bb < 8);

    // With 5 or more bytes left any write fits (7 + 32 <= 40 bits). Closer to
    // the end, accept the write only if the bits really fit, instead of
    // refusing every write into the last four bytes.
    if (remain < 5 && remain * 8 - bb < bits) {
        return 0;
    }

    // Left-align the payload so its first bit is bit 31 of the window.
    word <<= 32 - bits;
    Byte *c = b->cur;

    if (bb) {
        Bits bbi = 8 - bb;
        // Keep the bb bits already written in the current byte.
        c[0] = ((c[0] >> bbi) << bbi) | (word >> (bb + 24));
        if (remain > 4) {
            c[1] = word >> (bb + 16);
            c[2] = word >> (bb + 8);
            c[3] = word >> bb;
            c[4] = word << bbi;
        } else {
            // Tail of the buffer: store only bytes that exist. The capacity
            // check guarantees bits <= remain*8 - bb, so every bit that would
            // have landed past the end is a zero padding bit of `word`.
            if (remain > 1) {
                c[1] = word >> (bb + 16);
            }
            if (remain > 2) {
                c[2] = word >> (bb + 8);
            }
            if (remain > 3) {
                c[3] = word >> bb;
            }
        }
    } else {
        c[0] = word >> 24;
        if (remain > 3) {
            c[1] = word >> 16;
            c[2] = word >> 8;
            c[3] = word;
        } else {
            // Tail of the buffer: store only bytes that exist.
            if (remain > 1) {
                c[1] = word >> 16;
            }
            if (remain > 2) {
                c[2] = word >> 8;
            }
        }
    }

    // Advance by whole bytes and keep the leftover bit offset (0..7).
    bb += bits;
    b->cur += bb >> 3;
    b->bits = bb & 0x7;
    return 1;
}
// Finish a partially written byte: clear its not-yet-written low bits and
// move the write position to the next byte boundary. Does nothing when the
// position is already byte-aligned.
void buf_write_pad(Buffer *b) {
    if (b->bits == 0) {
        return;
    }
    // A non-zero bit offset implies the cursor is on a real, partly written byte.
    assert(b->cur < b->end);
    assert(b->bits < 8);

    // Shifting down and back up zeroes the low (8 - bits) positions.
    Bits unused = 8 - b->bits;
    *b->cur = (*b->cur >> unused) << unused;

    b->cur += 1;
    b->bits = 0;
}
// Write one whole byte. If the write position is in the middle of a byte,
// that byte is zero-padded and completed first (this happens even when the
// call then fails).
//
// Returns 1 on success, 0 when the buffer has no room left.
int buf_write_byte(Buffer *b, Byte byte) {
    if (b->bits != 0) {
        buf_write_pad(b);
    }
    if (b->cur >= b->end) {
        return 0;
    }
    b->cur[0] = byte;
    b->cur += 1;
    return 1;
}