#include <kernel/arch/generic.h>
#include <kernel/malloc.h>
#include <kernel/panic.h>
#include <kernel/util.h>
#include <shared/mem.h>
#include <stdbool.h>
#include <stdint.h>

/* Tag stored in every live kmalloc header; the bytes spell "allochdr" in
 * ASCII. kfree replaces it with ~MALLOC_MAGIC. */
#define MALLOC_MAGIC 0x616c6c6f63686472
/* Length of Allocation.desc; longer descriptions are truncated by kmalloc. */
#define DESCLEN 8

/* Bookkeeping header that kmalloc places at the start of each allocation;
 * the pointer handed to the caller is the address right after it. */
typedef struct Allocation Allocation;
struct Allocation {
	/* MALLOC_MAGIC while the allocation is live */
	uint64_t magic;
	/* total size in bytes, including this header */
	uint64_t len;
	/* doubly linked list of live allocations; malloc_last is the newest */
	Allocation *next, *prev;
	/* debug_caller(0..3) captured at allocation time */
	void *stacktrace[4];
	/* caller-supplied tag, space-padded and NOT NUL-terminated */
	char desc[DESCLEN];
};

/* Newest live kmalloc header; older ones are reached through ->prev. */
Allocation *malloc_last = NULL;

/* Page bitmap: bit i set means the page at pbitmap + i * PAGE_SIZE is in use.
 * The managed region starts at the bitmap itself. */
extern uint8_t pbitmap[]; /* linker.ld */
static size_t pbitmap_len; /* in bytes */
/* Page index at which page_alloc starts searching; page_free lowers it. */
static size_t pbitmap_firstfree = 0;


/* Reports whether page `i` is marked as used in the page bitmap. */
static bool
bitmap_get(long i)
{
	assert(i >= 0);
	size_t b = i / 8;
	assert(b < pbitmap_len);
	uint8_t byte = pbitmap[b];
	/* i >= 0 here, so i % 8 selects the same bit as i & 7 */
	return (byte >> (i % 8)) & 1;
}

/* Marks page `i` as used (v == true) or free (v == false) and returns the
 * bit's previous value. */
static bool
bitmap_set(long i, bool v)
{
	assert(i >= 0);
	size_t b = i / 8;
	assert(b < pbitmap_len);
	uint8_t *byte = &pbitmap[b];
	uint8_t mask = (uint8_t)(1u << (i % 8));
	bool was_set = (*byte & mask) != 0;
	*byte = v ? (*byte | mask) : (*byte & ~mask);
	return was_set;
}

/* Converts an address into a page index relative to pbitmap, truncating
 * toward zero. Addresses below pbitmap give negative indices.
 * The subtraction is done on uintptr_t: casting addresses above LONG_MAX to
 * long is implementation-defined and could yield a bogus negative index. */
static
long toindex(void *p)
{
	uintptr_t addr = (uintptr_t)p;
	uintptr_t base = (uintptr_t)pbitmap;
	if (addr >= base)
		return (long)((addr - base) / PAGE_SIZE);
	return -(long)((base - addr) / PAGE_SIZE);
}

/* Number of pages needed to hold `bytes` bytes, rounded up.
 * Computed as quotient + "is there a remainder" instead of
 * (bytes + PAGE_SIZE - 1) / PAGE_SIZE, which wraps to a tiny count for sizes
 * close to SIZE_MAX. */
static
size_t page_amt(size_t bytes)
{
	return bytes / PAGE_SIZE + (bytes % PAGE_SIZE != 0);
}

/* Sets up the page bitmap for the region from pbitmap up to memtop: every
 * page starts out free, then the pages occupied by the bitmap itself are
 * reserved. */
void
mem_init(void *memtop)
{
	kprintf("memory   %8x -> %8x\n", &_bss_end, memtop);

	/* one bit per page; a trailing partial byte's worth of pages is dropped */
	long page_count = toindex(memtop);
	pbitmap_len = page_count / 8;

	memset(pbitmap, 0, pbitmap_len);
	mem_reserve(pbitmap, pbitmap_len);
}

void
mem_reserve(void *addr, size_t len)
{
	kprintf("reserved %8x -> %8x\n", addr, addr + len);

	/* align to the previous page */
	size_t off = (uintptr_t)addr & PAGE_MASK;
	addr -= off;
	len += off;
	size_t first = toindex(addr);
	for (size_t i = 0; i * PAGE_SIZE < len; i++) {
		if ((first + i) / 8 >= pbitmap_len)
			break;
		bitmap_set(first + i, true);
	}
}

/* Prints every live kmalloc allocation, newest first: user pointer, user
 * size, description and the four recorded caller addresses. It ends with a
 * summary line. The "bytes" total includes the headers, and the page count
 * reflects that each allocation occupies whole pages. */
void
mem_debugprint(void)
{
	size_t count = 0;
	size_t bytes = 0;
	size_t pages = 0;

	kprintf("[kern] current allocations:\n");
	Allocation *a = malloc_last;
	while (a != NULL) {
		kprintf(
			"%08p %6dB %.8s ",
			(void*)(a + 1),
			a->len - sizeof(Allocation),
			a->desc
		);
		for (size_t k = 0; k < 4; k++)
			kprintf(" k/%08x", a->stacktrace[k]);
		kprintf("\n");

		count += 1;
		bytes += a->len;
		pages += page_amt(a->len);
		a = a->prev;
	}
	kprintf(
		"%d in total, %d bytes, %d pages = %dB used\n",
		count, bytes, pages, pages*PAGE_SIZE
	);
}

/* Allocates `pages` contiguous pages from the bitmap-managed region and
 * returns the address of the first one. A request for 0 pages is served as
 * 1 page. If no free run is long enough, prints a message and calls
 * panic_unimplemented().
 *
 * pbitmap_firstfree is kept as a lower bound on the first free page. The hint
 * is only moved past the new block when no free page was skipped before it;
 * otherwise it points at the first skipped hole. The hole was too small for
 * this request but may satisfy later, smaller ones. */
void
*page_alloc(size_t pages)
{
	/* i do realize how painfully slow this is */
	size_t streak = 0;
	size_t first_hole = SIZE_MAX; /* first free page seen by this scan */

	if (pages == 0)
		pages = 1;

	for (size_t i = pbitmap_firstfree; i < pbitmap_len * 8; i++) {
		if (bitmap_get(i)) {
			streak = 0;
			continue;
		}
		if (first_hole == SIZE_MAX)
			first_hole = i;
		if (++streak >= pages) {
			/* found hole big enough for this allocation */
			size_t start = i + 1 - pages;
			for (size_t j = 0; j < pages; j++)
				bitmap_set(start + j, true);
			pbitmap_firstfree = (first_hole == start) ? start + pages : first_hole;
			return pbitmap + start * PAGE_SIZE;
		}
	}
	kprintf("we ran out of memory :(\ngoodbye.\n");
	panic_unimplemented();
}

/* Like page_alloc, but fills the requested `pages` pages with zeroes first. */
void
*page_zalloc(size_t pages)
{
	void *mem = page_alloc(pages);
	size_t bytes = pages * PAGE_SIZE;
	memset(mem, 0, bytes);
	return mem;
}

/* Frees `pages` consecutive pages, starting with the page that contains
 * `addr`. Calls panic_invalid_state() for any page that was not marked as
 * used. */
void
page_free(void *addr, size_t pages)
{
	assert(addr >= (void*)pbitmap);
	size_t first = toindex(addr);
	size_t idx = first;
	for (size_t left = pages; left > 0; left--, idx++) {
		bool was_used = bitmap_set(idx, false);
		if (!was_used)
			panic_invalid_state();
	}
	/* let the next page_alloc scan start at (or before) this hole */
	if (first < pbitmap_firstfree)
		pbitmap_firstfree = first;
}

/* Asserts that `addr` looks like a live kmalloc result. It must be non-NULL
 * and preceded by a header carrying MALLOC_MAGIC, and that header must be
 * consistently linked with its neighbours in the live-allocation list. */
void
kmalloc_sanity(const void *addr)
{
	assert(addr);
	const Allocation *hdr = (const Allocation *)addr - 1;
	assert(hdr->magic == MALLOC_MAGIC);
	if (hdr->next) {
		assert(hdr->next->prev == hdr);
	}
	if (hdr->prev) {
		assert(hdr->prev->next == hdr);
	}
}

// TODO better kmalloc
/* Allocates `len` bytes and returns a pointer to them.
 * Every allocation occupies whole pages. An Allocation header sits at the
 * start of the first page, and the returned pointer is just past it.
 * `desc` may be NULL. It is copied into the header for mem_debugprint,
 * space-padded and truncated to DESCLEN chars.
 * Outside NDEBUG builds, the returned bytes are filled with 0xCC. */
void
*kmalloc(size_t len, const char *desc)
{
	Allocation *hdr;
	void *addr;
	size_t total;

	/* len + header must not wrap, or a tiny block would be handed out */
	assert(len <= SIZE_MAX - sizeof(Allocation));
	total = len + sizeof(Allocation);
	hdr = page_alloc(page_amt(total));
	hdr->magic = MALLOC_MAGIC;
	hdr->len = total;

	memset(hdr->desc, ' ', DESCLEN);
	if (desc) {
		for (int i = 0; i < DESCLEN; i++) {
			if (desc[i] == '\0') break;
			hdr->desc[i] = desc[i];
		}
	}

	/* append to the tail of the live-allocation list */
	hdr->next = NULL;
	hdr->prev = malloc_last;
	if (hdr->prev) {
		assert(!hdr->prev->next);
		hdr->prev->next = hdr;
	}

	for (size_t i = 0; i < 4; i++)
		hdr->stacktrace[i] = debug_caller(i);

	malloc_last = hdr;

	addr = (void*)hdr + sizeof(Allocation);
#ifndef NDEBUG
	/* poison only the caller's `len` bytes; using `total` here would write
	 * sizeof(Allocation) bytes past the end of the allocation */
	memset(addr, 0xCC, len);
#endif
	kmalloc_sanity(addr);
	return addr;
}

void
kfree(void *ptr)
{
	Allocation *hdr;
	size_t pages;
	if (ptr == NULL) return;

	hdr = ptr - sizeof(Allocation);
	kmalloc_sanity(ptr);
	pages = page_amt(hdr->len);

	hdr->magic = ~MALLOC_MAGIC; // (hopefully) detect double frees
	if (hdr->next)
		hdr->next->prev = hdr->prev;
	if (hdr->prev)
		hdr->prev->next = hdr->next;
	if (malloc_last == hdr)
		malloc_last = hdr->prev;
#ifndef NDEBUG
	memset(hdr, 0xC0, pages * PAGE_SIZE);
#endif
	page_free(hdr, pages);
}