#include <stdio.h>
#include <stdlib.h>
#include <math.h>

#include "jalloc.h"

/*
** The heap is divided into blocks of block_size bytes.
** Each bit in the bitmap represents one block (0 = free, 1 = used).
*/


/*
** Precomputed facts about one bitmap byte; the table below is indexed by the
** byte's value. Bit b of the byte describes block b of that 8-block group
** (0 = free).
*/
typedef struct
{
	unsigned char	bits;			// number of set (used) bits in the byte
	unsigned char	run[8];			// run[c]: number of consecutive free bits starting at bit c, counting upward
	unsigned char	max_run;		// length of the longest free run in the byte
	unsigned char	max_run_bit;	// start bit of that longest run (the highest start bit on ties)
	unsigned char	tail_run_bit;	// one above the highest set bit: where the free run at the top of the byte begins (0 if no bit is set)
} bit_lookupT;

/* Lookup table filled by calc_lookup(); entry v describes bitmap byte value v. */
bit_lookupT	bit_lookup[256];
/* Heap that jalloc(), jree(), mem_used() etc. operate on; set by new_heap() and set_arena(). */
heapT	*current_heap = NULL;

/*
** Locate the first occurrence of c (converted to char) in the string s, like
** strchr(). The terminating NUL counts as part of the string, so searching
** for '\0' returns a pointer to the terminator.
** Returns NULL if c does not occur.
*/
char	*jstrchr (char *s, int c)
{
	const char	ch = (char) c;	/* compare as char so bytes >= 0x80 match on signed-char targets */

	for (;; s++)
	{
		if (*s == ch) return s;
		if (!*s) return NULL;
	}
}

/*
** Copy n bytes from src to dest, like memcpy(); the regions must not overlap.
** Works through unsigned char pointers rather than arithmetic on void *,
** which is a compiler extension, not standard C.
** Returns dest.
*/
void	*jemcpy (void *dest, const void *src, size_t n)
{
	unsigned char		*d = dest;
	const unsigned char	*s = src;

	while (n--)
	{
		*d++ = *s++;
	}

	return (dest);
}

/*
** Fill the first n bytes of dest with (unsigned char)c, like memset().
** Uses an unsigned char pointer instead of arithmetic on void *, which is a
** compiler extension.
** Returns dest.
*/
void	*jemset (void *dest, int c, size_t n)
{
	unsigned char	*d = dest;
	const unsigned char	v = (unsigned char) c;

	while (n--)
	{
		*d++ = v;
	}

	return (dest);
}


/*
** Return the number of characters in s before its terminating NUL.
*/
size_t	jstrlen (const char *s)
{
	const char	*end = s;

	while (*end != '\0')
	{
		end++;
	}

	return (size_t) (end - s);
}

/*
** Compare two NUL-terminated strings like strcmp(). Characters are compared
** as unsigned char, so bytes >= 0x80 sort after ASCII regardless of whether
** plain char is signed.
** Returns -1 if s1 < s2, 0 if they are equal, 1 if s1 > s2.
*/
int	jstrcmp (const char *s1, const char *s2)
{
	const unsigned char	*a = (const unsigned char *) s1;
	const unsigned char	*b = (const unsigned char *) s2;

	/* Stop at the first difference or at the end of s1 (if s2 ends first the bytes differ) */
	while (*a && *a == *b)
	{
		a++;
		b++;
	}

	if (*a < *b) return -1;
	if (*a > *b) return  1;

	return (0);
}

/*
** Duplicate the NUL-terminated string s into memory obtained from jalloc().
** Returns the copy, or NULL if the allocation fails.
*/
char	*jstrdup (const char *s)
{
	size_t	n = jstrlen (s) + 1;	/* include the terminator */
	size_t	k;
	char	*copy = (char *) jalloc (n);

	if (copy == NULL)
	{
		return NULL;
	}

	for (k = 0; k < n; k++)
	{
		copy[k] = s[k];
	}

	return copy;
}

/*
** Select h as the heap that later jalloc()/jree()/mem_used() calls act on.
*/
void	set_arena (heapT *h)
{
	heapT	*arena = h;

	current_heap = arena;
}

/*
** Print the 8 bits of c as '0'/'1' digits, least-significant bit first.
*/
void	show_byte (unsigned char c)
{
	unsigned int	mask;

	for (mask = 0x01; mask <= 0x80; mask <<= 1)
	{
		printf ("%d", (c & mask) ? 1 : 0);
	}
}

/*
** Dump the current heap's bitmap: one 8-digit group per bitmap byte
** (least-significant bit first), each followed by '.', then a newline.
*/
void	show_bitmap (void)
{
	int	byte_idx = 0;

	while (byte_idx < current_heap->blocks / 8)
	{
		show_byte (current_heap->bit_map[byte_idx]);
		printf (".");
		byte_idx++;
	}

	printf ("\n");
}

void	calc_lookup (void)
{
	int		i, b, c;
	char	bit[8];

	for (i=0; i<256; i++)
	{
//		printf ("%3d ", i);

		bit_lookup[i].bits = 0;
		for (b=0; b<8; b++)		
		{
			if (i & (1 << b))
			{
				bit[b] = 1;
			//	printf ("1");
				bit_lookup[i].bits++;
			} else
			{
			//	printf ("0");
				bit[b] = 0;
			}
		}
		
		bit_lookup[i].max_run = 0;
		bit_lookup[i].max_run_bit = 0;

	//	printf (" ");

		for (c=0; c<8; c++)
		{
			b = c;
			while ((b < 8) && !bit[b]) b++;
			bit_lookup[i].run[c] = b - c;
			//printf ("%1d", b - c);
			if (b - c >= bit_lookup[i].max_run)
			{
				bit_lookup[i].max_run = b - c;
				bit_lookup[i].max_run_bit = c;
			}
		}

		c = 7;
		while ( (c >= 0) && !bit[c] ) c--;
		if (c < 0)
		{
			bit_lookup[i].tail_run_bit = 0;
		} else
		{
			bit_lookup[i].tail_run_bit = c+1;
		}

		//printf (" trb=%d, mrb=%d, mr=%d\n", bit_lookup[i].tail_run_bit, bit_lookup[i].max_run_bit, bit_lookup[i].max_run);
	}
}

heapT	*new_heap (long size, int block_size, void *base_ptr)
{
	heapT	*h;
	long	blocks;
	int		i;

	h = malloc (sizeof (heapT));
	if (!h) return NULL;

	h->size = size;
	h->block_size = block_size;

	h->heap = malloc (size);
	if (!h->heap)
	{
		free (h);
		return NULL;
	}

	blocks = size / block_size;	
	h->blocks = blocks * 7 / 8; // leave room for bitmap

	h->bit_map = h->heap;
	h->heap += blocks / 8;

	for (i=0; i<blocks/8; i++)
	{
		h->bit_map[i] = 0x00;	
	}

	h->first_free_block_idx = 0;

	current_heap = h;

	calc_lookup();

	return (h);
}

/*
** Convert an address inside the current heap's data area to its block number.
** Exits the process with an error message if the address lies outside
** [heap, heap + blocks * block_size).
*/
long	ptr_to_block_num (void *p)
{
	const char	*base = (const char *) current_heap->heap;
	const char	*addr = (const char *) p;
	long		n;	/* long, not int, so large heaps cannot overflow the index */

	if (addr < base)
	{
		fprintf (stderr, "Asking for pointer below heap start\n");
		exit (-1);
	}
	n = (long) ((addr - base) / current_heap->block_size);

	if (n >= current_heap->blocks)
	{
		fprintf (stderr, "Asking for pointer above heap end\n");
		exit (-1);
	}

	return n;
}

/*
** Release a pointer previously returned by jalloc() on the current heap.
** jalloc() stores the allocation's block count at the start of its first
** block and returns the address block_size bytes past it, so the header is
** found by stepping back one whole block.
** NULL and already-freed pointers (header count 0) are ignored; addresses
** outside the heap, or a count running past the heap end, are fatal errors.
*/
void	jree (void *p)
{
	unsigned char	*hdr, *map;
	unsigned int	blocks;
	long			start_block, end_block, blk, s;

	if (!p || !current_heap) return;

	hdr = (unsigned char *) p - current_heap->block_size;

	/* Validate the header address before reading through it */
	start_block = ptr_to_block_num (hdr);

	blocks = *(unsigned int *) hdr;

	// TODO - WOULD BE GOOD TO HAVE A CHECKSUM IN THE BLOCK COUNT
	if (!blocks) 
	{
		return;
	}

	*(unsigned int *) hdr = 0;	/* a second jree of the same pointer becomes a no-op */

	end_block = start_block + (long) blocks;
	if (end_block > current_heap->blocks)
	{
		fprintf (stderr, "Freeing blocks beyond heap end\n");
		exit (-1);
	}

	/* Every bitmap byte before first_free_block_idx must be full; this one no longer is */
	s = start_block / 8;
	if (s < current_heap->first_free_block_idx)
	{
		current_heap->first_free_block_idx = s;
	}

	map = (unsigned char *) current_heap->bit_map;
	blk = start_block;

	/* Leading partial byte: clear bits individually (&= ~ never disturbs neighbours) */
	while (blk < end_block && (blk % 8) != 0)
	{
		map[blk / 8] &= (unsigned char) ~(1u << (blk % 8));
		blk++;
	}

	/* Whole bytes */
	while (end_block - blk >= 8)
	{
		map[blk / 8] = 0x00;
		blk += 8;
	}

	/* Trailing partial byte */
	while (blk < end_block)
	{
		map[blk / 8] &= (unsigned char) ~(1u << (blk % 8));
		blk++;
	}
}

/*
** Return the number of bytes in use on the current heap: the count of set
** bitmap bits times block_size. This includes each allocation's header block.
*/
long	mem_used (void)
{
	/* index the lookup table through unsigned char so it can never go negative */
	const unsigned char	*map = (const unsigned char *) current_heap->bit_map;
	long				nbytes = current_heap->blocks / 8;
	long				used = 0;
	long				i;

	for (i = 0; i < nbytes; i++)
	{
		used += bit_lookup[map[i]].bits;
	}

	return (used * current_heap->block_size);
}

void	*jalloc (size_t size)
{
	int		i, j, k,  e, start_block;
	unsigned int		blocks;
	int		b;
	void	*ret;
	int		loop;

	if (!size)
	{
		return NULL;
	}

	size += current_heap->block_size; // To fit the length info
	blocks = ceil (size / (float) current_heap->block_size);

	for (loop=0; loop<2; loop++)
	{

	i = current_heap->first_free_block_idx;

	if (i >= current_heap->blocks/8)
	{
		current_heap->first_free_block_idx = 0;
		i = 0;
	}

	if (blocks <= 8)
	{
		// This will fit in at most 2 bitmap areas
		while (i < current_heap->blocks / 8)
		{
			// See if it fits inside this area
			if (bit_lookup[(int)current_heap->bit_map[i]].max_run >= blocks)
			{
				// We have found a bitmap area that contains a run of free
				// blocks that will fit this request
				b = bit_lookup[(int)current_heap->bit_map[i]].max_run_bit;

				start_block = i * 8 + b;

				ret = current_heap->heap + start_block*current_heap->block_size;
				jemcpy (ret, &blocks, sizeof (int));
				ret += current_heap->block_size;	

				// Mark the used blocks
			
				for (j=b; j<b+blocks; j++)
				{
					current_heap->bit_map[i] += 1 << j;
				}
				if (b+blocks == 8) current_heap->first_free_block_idx++;
				return ret;
			} else
			if (i < (current_heap->blocks/8) - 1)
			{
				// see if it can span across area boundaries	
				b = (8 - bit_lookup[(int)current_heap->bit_map[i]].tail_run_bit) +
					bit_lookup[(int)current_heap->bit_map[i+1]].run[0];

				if (b >= blocks)
				{
					// We have found a place
					start_block = i * 8 + bit_lookup[(int)current_heap->bit_map[i]].tail_run_bit;

					ret = current_heap->heap + start_block*current_heap->block_size;
					jemcpy (ret, &blocks, sizeof (int));
					ret += current_heap->block_size;	
					current_heap->first_free_block_idx++;
				
					// Mark the used blocks
				
					k = bit_lookup[(int)current_heap->bit_map[i]].tail_run_bit;

					for (j=k; j<8; j++)
					{
						current_heap->bit_map[i] += 1 << j;
					}
					for (j=0; j<blocks-(8-k); j++)
					{
						current_heap->bit_map[i+1] += 1 << j;
					}

					return ret;
				}
			}
			i ++;
		}
	} else
	{
		// Needs to span mutliple bitmap areas

		while (i < current_heap->blocks / 8)
		{
			// # of blocks available at the end of this area
			b = 8 - bit_lookup[(int)current_heap->bit_map[i]].tail_run_bit;
			if (!b)
			{
				// No room in this area, skip
				i++;
				continue;
			}
			j = i;
			while ( (b < blocks) && (j < current_heap->blocks/8) )
			{
				j ++;
				if (!current_heap->bit_map[j])
				{
					b += 8;
				} else
				{
					b += bit_lookup[(int)current_heap->bit_map[j]].run[0];
					break;
				}
			}

			if (b >= blocks)
			{
				// Found space

				start_block = i * 8 + bit_lookup[(int)current_heap->bit_map[i]].tail_run_bit;

				ret = current_heap->heap + start_block*current_heap->block_size;
				jemcpy (ret, &blocks, sizeof (int));
				ret += current_heap->block_size;	
				current_heap->first_free_block_idx += blocks / 8;

				e = j;
				// Number of blocks used at the start area + the middle areas
				k = (e - (i+1))*8 + (8 - bit_lookup[(int)current_heap->bit_map[i]].tail_run_bit);

				// Mark the start
				for (j=bit_lookup[(int)current_heap->bit_map[i]].tail_run_bit; j<8; j++)
				{
					current_heap->bit_map[i] += 1 << j;
				}


				// Mark the beginning of area e
				for (j=0; j<blocks-k; j++)
				{
					current_heap->bit_map[e] += 1 << j;
				}

				// Mark the middle
				for (j=i+1; j<e; j++)
				{
					current_heap->bit_map[j] = 0xff;	
				}

				return ret;
			}
	
			i++;
		}
	}
	}

	return NULL;	
}

#define TEST_SIZE 200

/*
** Randomised stress test for an allocator pair (alloc, ree).
** Creates a fresh 320 MiB heap with 4-byte blocks, which becomes the current
** heap. It then runs 10000 rounds; each round does three things:
**   - allocate and fill the empty slots whose index matches the round mod 3;
**   - verify the contents of those slots;
**   - free the slots whose (index + 1) matches the round mod 17.
** Finally it frees everything and checks that mem_used() reports zero.
** mem_used() always inspects the current jalloc heap, whichever allocator is
** passed in.
** Returns 0 on success, -1 if the heap cannot be created, -2 on allocation
** failure, -3 on data corruption, -4 if memory is still in use.
*/
int test_jalloc (void* (*alloc)(size_t), void (*ree)(void *))
{
	heapT	*h;
	int		i,j,r;
	int		testd[TEST_SIZE];
	int		testz[TEST_SIZE];
	int		*test[TEST_SIZE];
	long	requested;

	h = new_heap (1024*1024*320, 4, NULL);
	if (!h)
	{
		fprintf (stderr, "Failed to make a new heap\n");
		return (-1);
	}

	for (i=0; i<TEST_SIZE; i++)
	{
		test[i] = NULL;
	}
	
	for (r=0; r<10000; r++)
	{
		requested = 0;
		for (i=0; i<TEST_SIZE; i++)
		{
			if ((i%3 == r%3) && !test[i])
			{
				testd[i] = random();
				testz[i] = (random() % 1024) + 1;
				test[i] = (int *) alloc (sizeof(int) * testz[i]);	
				requested += sizeof (int) * testz[i];
				if (!test[i])
				{
					printf ("[%d] FAILED TO ALLOC %zu bytes TOTAL %ld requested, %ld used\n", i, sizeof(int) * testz[i], requested, mem_used());
					return (-2);
				}
				for (j=0; j<testz[i]; j++)
				{
					test[i][j] = (testd[i] ^ i);
				}
			}
		}
	
		/* Verify this round's group; only live (non-NULL) slots hold data */
		for (i=0; i<TEST_SIZE; i++)
		{
			if ((i%3 == r%3) && test[i])
			{
				j = 0;
				while ((j < testz[i]) && (test[i][j] == (testd[i] ^ i))) j++;
				if (j < testz[i])
				{
					printf ("FAILED ON ITEM %d[%d/%d] (d=%d, found %d)\n", i,j,testz[i], testd[i], test[i][j]);
					return (-3);
				}
			}
		}
	
		for (i=0; i<TEST_SIZE; i++)
		{
			if ((i+1)%17 == r%17)
			{
				ree (test[i]);
				test[i] = NULL;
			}
		}
	
		printf (".");	
		fflush (stdout);
	}

	for (i=0; i<TEST_SIZE; i++)
	{
		ree (test[i]);
	}

	if (mem_used())
	{
		printf ("Leaked Bytes = %ld\n", mem_used());
		return (-4);
	}
	
	printf ("\nTest Passed OK!\n");
	return (0);
}

/*
** Write heap h to the file `name`.
** File layout: the heapT struct (written verbatim, pointer fields included),
** then an h->size-byte data slot, then the h->blocks/8-byte bitmap.
** new_heap() places the bitmap at the front of the malloc'd region and
** advances h->heap past it. Only h->size minus that prefix bytes of data
** exist at h->heap, so reading h->size bytes from there would run off the
** end of the region. The rest of the slot is zero-padded instead, which
** keeps the file layout unchanged.
** Returns 0 on success, 1 on any open/write/close failure.
*/
int		save_heap (heapT *h, char *name)
{
	FILE	*fp;
	size_t	reserved, data_len, bm_len, k;
	int		ok;

	fp = fopen (name, "wb");
	if (!fp)
	{
		fprintf (stderr, "Failed to write to heap file '%s'\n", name);
		return 1;
	}

	reserved = (size_t) ((unsigned char *) h->heap - (unsigned char *) h->bit_map);
	data_len = (size_t) h->size - reserved;
	bm_len = (size_t) (h->blocks / 8);

	ok = fwrite (h, sizeof (heapT), 1, fp) == 1;
	ok = ok && fwrite (h->heap, 1, data_len, fp) == data_len;
	for (k = 0; ok && k < reserved; k++)
	{
		ok = fputc (0, fp) != EOF;
	}
	ok = ok && fwrite (h->bit_map, 1, bm_len, fp) == bm_len;

	/* fclose flushes buffered data, so its result matters too */
	if (fclose (fp) != 0)
	{
		ok = 0;
	}

	if (!ok)
	{
		fprintf (stderr, "Failed to write to heap file '%s'\n", name);
		return 1;
	}

	return 0;
}

heapT	*load_heap (char *name)
{	
	FILE	*fp;
	heapT	*h;
	heapT	H;

	fp = fopen (name, "r");
	if (!fp)
	{
		fprintf (stderr, "Failed to read from heap file '%s'\n", name);
		return NULL;
	}

	fread (&H, sizeof (heapT), 1, fp);

	h = new_heap (H.size, H.block_size);

	fread (h->heap, 1, h->size, fp);
	fread (h->bit_map, 1, h->blocks / 8, fp);

	h->used = H.used;
	h->first_free_block_idx = H.first_free_block_idx;

	fclose (fp);

	return h;
}
