/* bs128.c */
/*
 * Copyright (c) 2016 National Institute of Informatics in Japan,
 * All rights reserved.
 *
 * This file or a portion of this file is licensed under the terms of
 * the NAREGI Public License, found at http://www.naregi.org/download/index.html.
 * If you redistribute this file, with or without modifications, you must
 * include this notice in the file.
 */
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <assert.h>

#include "bs128.h"

/*
 * Allocate storage for one bs128_t on the heap.
 *
 * The object is zero-filled (len == 0, every buf byte 0) rather than left
 * indeterminate, so a caller that forgets to initialise a field does not
 * read garbage.
 *
 * @return the new object (release with bs_free), or NULL on allocation
 *         failure
 */
bs128_t *bs_alloc(void)
{
	bs128_t *ret;

	ret = calloc(1, sizeof(*ret));
	return ret;
}

/*
 * Allocate a new bitstring whose logical length is `bit` bits, with every
 * byte of the buffer cleared.
 *
 * @param bit  logical length in bits, 0..128
 * @return     the new bitstring (release with bs_free), or NULL if the
 *             allocation fails
 */
bs128_t *bs_new(int32_t bit)
{
	bs128_t *ret;

	/* A negative length would later turn into negative byte indices in
	 * bs_get_data()/bs_set_zero(), so reject it here as well. */
	assert(bit >= 0 && bit <= 128);

	ret = bs_alloc();
	if (ret == NULL) {
		return NULL;
	}

	ret->len = bit;
	memset(ret->buf, 0, BS128_BYTE_LEN);

	return ret;
}

/*
 * Release a bitstring obtained from bs_alloc/bs_new/bs_clone.
 *
 * Passing NULL is a no-op, matching free(NULL); previously the len store
 * dereferenced the NULL pointer before free() was ever reached.
 */
void bs_free(bs128_t *a)
{
	if (a == NULL) {
		return;
	}

	a->len = 0;
	free(a);
}

/*
 * Report the logical length, in bits, recorded in `bs`.
 */
int32_t bs_get_len(bs128_t *bs)
{
	int32_t nbits = bs->len;

	return nbits;
}


/*
 * Clear the `len` bits starting at bit position `begin` (bit 0 is the MSB
 * of buf[0]).  Bits outside [begin, begin + len) are left untouched.
 *
 * @param bs     bitstring to modify in place
 * @param begin  first bit to clear, 0..127
 * @param len    number of bits to clear, >= 0, with begin + len <= 128;
 *               0 is a no-op
 */
void bs_set_zero(bs128_t *bs, int32_t begin, int32_t len)
{
	int32_t end;
	int32_t i, j;
	uint8_t hi = 0xfe;
	uint8_t lo = 0x7f;

	assert(begin >= 0 && begin < 128);
	assert(len >= 0 && begin + len <= 128);

	/*
	 * Must return early: with len == 0, end = begin - 1 may fall in the
	 * previous byte (or be -1), which would clear whole bytes below and
	 * write bs->buf out of bounds.
	 */
	if (len == 0) {
		return;
	}

	end = begin + len - 1;

	i = begin >> 3;		/* byte holding the first cleared bit */
	j = end >> 3;		/* byte holding the last cleared bit */
	/* hi: bits of byte i that precede `begin` (kept).  The shift happens in
	 * int and is truncated back to 8 bits on assignment. */
	hi <<= ((begin & 0x07) ^ 0x07);
	/* lo: bits of byte j that follow `end` (kept). */
	lo >>= (end & 0x07);
	if (i == j) {
		bs->buf[i] &= (hi | lo);
	} else {
		bs->buf[i] &= hi;
		for (i++; i < j; i++) {
			bs->buf[i] = 0;
		}
		bs->buf[j] &= lo;
	}
}

/*
 * Export the first `len` bits of `bs` into `out`.
 *
 * `out` always receives BS128_BYTE_LEN bytes, regardless of `len`.  The
 * bits at positions [0, len) are copied from bs->buf (bit 0 is the MSB of
 * byte 0), and every bit at position >= len is cleared.  `out` must
 * therefore have room for BS128_BYTE_LEN bytes.
 *
 * @param bs   source bitstring (not modified)
 * @param out  destination buffer; may be bs->buf itself
 * @param len  number of leading bits to keep, 1..128 (same contract as
 *             bs_set_data)
 */
void bs_get_data(bs128_t *bs, uint8_t *out, int32_t len)
{
	int32_t lsb;
	int32_t j;
	uint8_t mask;
	uint8_t buf[BS128_BYTE_LEN];

	/* len <= 0 would make j negative and hand memcpy a negative size. */
	assert(len > 0 && len <= 128);

	memset(buf, 0, sizeof(buf));

	lsb = len - 1;		/* position of the last bit to keep */
	j = lsb >> 3;		/* byte holding that bit */
	/*
	 * Keep the top (lsb & 7) + 1 bits of byte j: 0x80 for one bit, up to
	 * 0xff for eight.  Computed in unsigned arithmetic; the former
	 * int8_t 0x80 >> k relied on implementation-defined conversion and
	 * sign extension.
	 */
	mask = (uint8_t)~(0x7fu >> (lsb & 0x07));

	memcpy(buf, bs->buf, j);
	buf[j] = bs->buf[j] & mask;

	/* Staged through a local copy so that `out` may alias bs->buf. */
	memcpy(out, buf, BS128_BYTE_LEN);
}

/*
 * Overwrite the first `len` bits of `bs` with the first `len` bits of `in`.
 *
 * Whole bytes before the byte holding bit len-1 are copied verbatim.  In
 * that final byte, only the leading bits up to position len-1 come from
 * `in`; the trailing bits of bs->buf in that byte, and every later byte,
 * are preserved.  bs->len is not modified.
 *
 * @param bs   destination bitstring
 * @param in   source bytes; at least (len + 7) / 8 bytes are read
 * @param len  number of bits to copy, 1..128
 */
void bs_set_data(bs128_t *bs, uint8_t *in, int32_t len)
{
	int32_t last_bit;
	int32_t last_byte;
	uint8_t keep;
	uint8_t merged;

	assert(len > 0 && len <= 128);

	last_bit = len - 1;
	last_byte = last_bit >> 3;
	/* Bits of the final byte that lie after last_bit: these stay as-is. */
	keep = 0x7f;
	keep >>= (last_bit & 0x07);

	memcpy(bs->buf, in, last_byte);

	merged = (uint8_t)((bs->buf[last_byte] & keep) | (in[last_byte] & ~keep));
	bs->buf[last_byte] = merged;
}

/*
 * Copy the whole bitstring (length and buffer) from `from` into `to`.
 *
 * Uses struct assignment rather than memcpy: memcpy on overlapping
 * storage (e.g. bs_copy(x, x)) is undefined behaviour, while assignment
 * between exactly-overlapping objects of the same type is well-defined.
 */
void bs_copy(bs128_t *from, bs128_t *to)
{
	*to = *from;
}

/*
 * Create a heap copy of `a`: same len, same buffer contents.
 *
 * @return the duplicate (release with bs_free), or NULL if bs_new fails
 */
bs128_t *bs_clone(bs128_t *a)
{
	bs128_t *copy = bs_new(a->len);

	if (copy != NULL) {
		memcpy(copy->buf, a->buf, sizeof(a->buf));
	}

	return copy;
}

/*
 * Test whether bit position `bit` is set (bit 0 is the MSB of buf[0]).
 *
 * @return non-zero if the bit is set, 0 otherwise.  The non-zero value is
 *         the bit's weight within its byte (0x80 for position 0 of a byte,
 *         0x01 for position 7), not necessarily 1.
 */
int bs_check_bit(bs128_t *a, int32_t bit)
{
	uint8_t octet;
	uint8_t probe;

	assert(bit >= 0 && bit < 128);

	octet = a->buf[bit >> 3];
	/* Position k within a byte counts from the MSB, so its mask is 0x80 >> k. */
	probe = (uint8_t)(0x80 >> (bit & 0x07));

	return octet & probe;
}

/*
 * XOR every buffer byte of `b` into `a`, in place.
 *
 * Both operands must have the same len.  Otherwise a diagnostic is
 * printed to stderr and `a` is left untouched.
 *
 * @return 0 on success, -1 on a length mismatch
 */
int bs_xor(bs128_t *a, bs128_t *b)
{
	int idx;

	if (a->len == b->len) {
		for (idx = 0; idx < BS128_BYTE_LEN; ++idx) {
			a->buf[idx] = (uint8_t)(a->buf[idx] ^ b->buf[idx]);
		}
		return 0;
	}

	fprintf(stderr, "size mismatch: a=%d b=%d\n", a->len, b->len);
	return -1;
}

/**
 * bit right shift.
 *
 * Logical right shift of the whole 128-bit buffer by `bit` positions.
 * Bits move toward higher byte indices (bit 0 is the MSB of buf[0]).  The
 * vacated leading positions are filled with zeros, and bits shifted past
 * the end of the last byte are discarded.  a->len is not changed.
 *
 * 0xb8     0x3b         0x3b     0x78
 * [0]      [1]      ... [14]     [15]
 * 10111000 00111011 ... 00111011 01111000
 *
 * bs_rshift(a, 1)
 * 0x5c     0x1d         0x?d     0xbc
 * [0]      [1]      ... [14]     [15]
 * 01011100 00011101 ... ?0011101 10111100
 *
 * (The top bit of [14] comes from the elided byte [13].)
 *
 * @param a    bitstring to shift in place
 * @param bit  shift amount, 0..128.  Only 1..7 used to be accepted; 0 is
 *             now a no-op and 128 clears the buffer.
 */
void bs_rshift(bs128_t *a, int32_t bit)
{
	uint8_t *w;
	int32_t nbytes;
	int32_t k;
	int32_t i;

	assert(bit >= 0 && bit <= 128);

	w = a->buf;
	nbytes = bit >> 3;	/* whole-byte part of the shift */
	k = bit & 0x07;		/* remaining sub-byte part */

	/*
	 * Walk from the last byte down.  w[i] is built only from
	 * w[i - nbytes] and w[i - nbytes - 1].  Those have indices <= i, and
	 * the lower ones have not been overwritten yet.
	 * <http://stackoverflow.com/questions/5996384/bitwise-shift-operation-on-a-128-bit-number>
	 */
	for (i = BS128_BYTE_LEN - 1; i >= 0; i--) {
		int32_t src = i - nbytes;
		uint8_t cur = (src >= 0) ? w[src] : 0;
		uint8_t prev = (src >= 1) ? w[src - 1] : 0;

		if (k == 0) {
			w[i] = cur;
		} else {
			/* Low k bits of prev become the high bits of w[i]. */
			w[i] = (uint8_t)((prev << (8 - k)) | (cur >> k));
		}
	}
}

/*
 * Print `a` to stdout.  The buffer is written as one hex string (two
 * digits per byte, buf[0] first), followed by the bit length.  A NULL
 * pointer prints "NULL".
 */
void bs_print(bs128_t *a)
{
	int n;

	if (a == NULL) {
		printf("NULL\n");
		return;
	}

	printf("0x");
	for (n = 0; n < BS128_BYTE_LEN; ++n) {
		printf("%.2x", a->buf[n]);
	}
	printf(" len: %d bit\n", a->len);
}
