/*
This program is distributed under the terms of the 'MIT license'. The text
of this licence follows...

Copyright (c) 2007 J.D.Medhurst (a.k.a. Tixy)

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
*/

/**
@file

@brief String utilities.
*/

#include "common.h"
#include "stringf.h"



//
// StringFormatter::ConversionSpec
//

/**
@brief Fetch the next integer argument from the variable argument list.

The argument is read with the size selected by the length modifier bits in
@a flags and then widened to a longlong. When the SignedInt bit is set the value
is sign extended from its original size, otherwise it is zero extended.

@param flags	Bitmask containing length modifier bits (LengthMod_ll, LengthMod_l,
				LengthMod_h, LengthMod_hh) and optionally SignedInt.

@return The argument value.
*/
longlong StringFormatter::ConversionSpec::GetIntArg(unsigned flags)
	{
	va_list& args = Args;
	const bool isSigned = (flags&SignedInt)!=0;

	// long long arguments need no further widening...
	if(flags&LengthMod_ll)
		return va_arg(args,ulonglong);

	// long arguments...
	if(flags&LengthMod_l)
		{
		unsigned long int wide = va_arg(args,unsigned long int);
		return isSigned ? (longlong)(signed long int)wide : (longlong)wide;
		}

	// Everything else is no bigger than an int. Such arguments are promoted to
	// int when passed through a variable argument list, so read an int and then
	// narrow it to the size the caller asked for...
	unsigned narrow = va_arg(args,unsigned int);

	if(flags&LengthMod_h)
		return isSigned ? (longlong)(signed short int)narrow : (longlong)(unsigned short int)narrow;

	if(flags&LengthMod_hh)
		return isSigned ? (longlong)(signed char)narrow : (longlong)(unsigned char)narrow;

	return isSigned ? (longlong)(signed int)narrow : (longlong)narrow;
	}


/**
@brief Read a field width or precision value from a format string.

If the text starts with '*' the value is taken from the next integer argument,
otherwise any leading decimal digits are converted to a number (zero if there
are no digits).

@param format	Pointer to the text to decode.
@param[out] val	The decoded value.

@return Pointer to the first character after the text consumed.
*/
const char* StringFormatter::ConversionSpec::ReadInt(const char* format, int& val)
	{
	// '*' means the value is supplied as an argument...
	if(*format=='*')
		{
		val = GetIntArg();
		return format+1;
		}

	// otherwise accumulate decimal digits until a non-digit is found...
	unsigned number = 0;
	for(;;)
		{
		// subtracting '0' maps the digits to 0..9 and everything else to a
		// value outside of that range
		unsigned digit = (unsigned)*format-'0';
		if(digit>=10u)
			break;
		number = number*10+digit;
		++format;
		}

	val = number;
	return format;
	}


/**
@brief Decode a single conversion specification from a format string.

Parses, in order: flag characters, field width, optional precision, optional
length modifier and finally the conversion specifier character. The results are
stored in #Flags, #FieldWidth, #Precision and #ConversionSpecifier.

@param format	Pointer to the character following the '%' which introduced the
				conversion specification.

@return Pointer to the character after the conversion specifier. If the format
		string ended before a specifier was found, this points at the
		terminating null character.
*/
const char* StringFormatter::ConversionSpec::Decode(const char* format)
	{
	// collect flag characters, stopping at the first character which isn't one...
	unsigned bits = 0;
	for(;;++format)
		{
		switch(*format)
			{
			case '-': bits |= FlagMinus; continue;
			case '+': bits |= FlagPlus;  continue;
			case ' ': bits |= FlagSpace; continue;
			case '#': bits |= FlagHash;  continue;
			case '0': bits |= FlagZero;  continue;
			default: break;
			}
		break;
		}

	// field width, a negative value (only possible via '*') means left justify...
	format = ReadInt(format,FieldWidth);
	if(FieldWidth<0)
		{
		bits |= FlagMinus;
		FieldWidth = -FieldWidth;
		}

	// precision, -1 indicates that none was specified...
	Precision = -1;
	if(*format=='.')
		format = ReadInt(format+1,Precision);

	// get length modifier...
	#define LENGTH_MODIFIER(type)						\
		(												\
		sizeof(type)==sizeof(char)?LengthMod_hh :		\
		sizeof(type)==sizeof(short)?LengthMod_h :		\
		sizeof(type)==sizeof(int)?0 :					\
		sizeof(type)==sizeof(long)?LengthMod_l :		\
		sizeof(type)==sizeof(longlong)?LengthMod_ll :	\
		LengthMod_unknown								\
		)

	switch(*format)
		{
		case 'h':
			++format;
			if(*format=='h')
				{
				// 'hh'
				bits |= LengthMod_hh;
				++format;
				}
			else
				bits |= LengthMod_h;
			break;

		case 'l':
			++format;
			if(*format=='l')
				{
				// 'll'
				bits |= LengthMod_ll;
				++format;
				}
			else
				bits |= LengthMod_l;
			break;

		case 'L':
			++format;
			bits |= LengthMod_L;
			break;

		case 'j':
			{
			++format;
			bits |= LENGTH_MODIFIER(intmax_t);
			ASSERT_COMPILE(LENGTH_MODIFIER(intmax_t)!=LengthMod_unknown);
			}
			break;

		case 'z':
			{
			++format;
			bits |= LENGTH_MODIFIER(size_t);
			ASSERT_COMPILE(LENGTH_MODIFIER(size_t)!=LengthMod_unknown);
			}
			break;

		case 't':
			{
			++format;
			bits |= LENGTH_MODIFIER(ptrdiff_t);
			ASSERT_COMPILE(LENGTH_MODIFIER(ptrdiff_t)!=LengthMod_unknown);
			}
			break;

		default:
			// no length modifier, leave character for the code below
			break;
		}

	// left justification overrides zero padding...
	if(bits&FlagMinus)
		bits &= ~FlagZero;
	Flags = bits;

	// finally, the conversion specifier. Don't step past the end of the format
	// string if it finished prematurely...
	char specifier = *format;
	ConversionSpecifier = specifier;
	if(specifier)
		++format;
	return format;
	}



//
// StringFormatter
//

/**
@brief Convert a value into hexadecimal text.

Characters are stored backwards in memory, ending immediately before @a dst.

@param val			The value to convert.
@param dst			Pointer to the end of the buffer which will receive the text.
@param precision	Minimum number of digits to produce; a negative value is
					treated as 1. If this is zero and @a val is zero, no text
					is produced.
@param x			Either 'x' or 'X'. Selects the case of the letters used for
					digits and is also the character used in any prefix.
@param prefix		If non-zero and @a val is non-zero, the text is prefixed
					with '0' followed by @a x.

@return Pointer to the first character of the text produced.
*/
char* StringFormatter::PushHex(ulonglong val, char* dst, int precision, int x, int prefix)
	{
	if(val==0)
		{
		if(precision==0)
			return dst; // zero value with zero precision produces nothing
		prefix = 0;     // a zero value never gets a prefix
		}

	// number of digits still required, negative precision means 'at least one'...
	int digitsWanted = precision<0 ? 1 : precision;

	// character for digit value 10, 'A' when x is 'X' and 'a' when x is 'x'...
	const int letterBase = x-'X'+'A';

	// produce digits, least significant first...
	do
		{
		int nibble = (int)(val&0xf);
		*--dst = (char)(nibble<10 ? '0'+nibble : letterBase+nibble-10);
		val >>= 4;
		--digitsWanted;
		}
	while(val!=0);

	// add leading zeros to reach the required precision...
	for(; digitsWanted>0; --digitsWanted)
		*--dst = '0';

	if(prefix)
		{
		*--dst = (char)x;
		*--dst = '0';
		}

	return dst;
	}


/**
@brief Convert a value into octal text.

Characters are stored backwards in memory, ending immediately before @a dst.

@param val			The value to convert.
@param dst			Pointer to the end of the buffer which will receive the text.
@param precision	Minimum number of digits to produce; a negative value is
					treated as 1. If this is zero and @a val is zero, no digits
					are produced (but see @a prefix).
@param prefix		If non-zero, the text is guaranteed to start with '0'; an
					extra '0' is added only if the text doesn't already start
					with one.

@return Pointer to the first character of the text produced.
*/
char* StringFormatter::PushOctal(ulonglong val, char* dst, int precision, int prefix)
	{
	if(!precision && !val)
		{
		// Zero value with zero precision produces no digits, however the
		// alternative form ('#' flag) still requires that the result starts
		// with a zero, so in that case a single '0' is produced.
		if(prefix)
			*--dst = '0';
		return dst;
		}

	if(precision<0)
		precision = 1; // -ve precision means produce as many digits as required, i.e. 1 or more.

	// produce digits, least significant first...
	do
		{
		unsigned c = val&0x7;
		c += '0';
		*--dst = c;
		--precision;
		val >>= 3;
		}
	while(val);

	// add leading zeros to reach the required precision...
	while(--precision>=0)
		*--dst = '0';

	// at least one digit has been stored, so *dst is valid to examine here
	if(prefix && *dst!='0')
		*--dst = '0';

	return dst;
	}


/**
@brief Convert an unsigned value into decimal text.

Characters are stored backwards in memory, ending immediately before @a dst.
Division by ten is performed with shifts and adds rather than the division
operator.

@param val			The value to convert.
@param dst			Pointer to the end of the buffer which will receive the text.
@param precision	Minimum number of digits to produce; a negative value is
					treated as 1. If this is zero and @a val is zero, no text
					is produced.

@return Pointer to the first character of the text produced.
*/
char* StringFormatter::PushDecimal(ulonglong val, char* dst, int precision)
	{
	if(!precision && !val) // produce nothing if precision is zero and val is zero
		return dst;

	if(precision<0)
		precision = 1; // -ve precision means produce as many digits as required, i.e. 1 or more.

	if(val<=0xffffffffu)
		{
		// value fits into 32 bits, so we can avoid 'long long' variables (and be faster)...
		unsigned a = (unsigned)val;
		// note, this loop produces no digits when a is zero; in that case the
		// zero padding loop at the end of the function supplies the digits
		while(a)
			{
			// calculate a/10 and a%10 using fixed point arithmetic, see http://www.cs.uiowa.edu/~jones/bcd/divide.html ...

			// calculate q = a/10 by multiplying by 1/10 accurate to 35 binary places...
			unsigned q = a>>1;
			q += q>>1;
			q += q>>4;
			q += q>>8;
			q += q>>16;
			q  = q>>3;
			// q now equals either a/10 or a/10-1

			// calculate r = remainder of a/10
			unsigned r = a - q*10;
			if(r>=10)
				{
				// remainder too big (due to rounding error in approximate divide by 10) so adjust values...
				r -= 10;
				++q;
				}
			a = q; // a now equals the original a/10

			*--dst = r+'0'; // store single digit
			--precision;
			}
		}
	else
		{
		// value doesn't fit into 32 bits, so we have to use 'long long' variables...
		ulonglong a = val;
		do
			{
			// calculate a/10 and a%10 using fixed point arithmetic, see http://www.cs.uiowa.edu/~jones/bcd/divide.html ...

			// calculate q = a/10 by multiplying by 1/10 accurate to 131 binary places...
			ulonglong q = a>>1;
			q += q>>1;
			q += q>>4;
			q += q>>8;
			q += q>>16;
			q += (q>>32);
			q += ((q>>32)>>32); // a shift by 64 can generate warnings, so shift in two steps; hopefully the compiler realises this is a nop when long long is 64 bits or less
			ASSERT_COMPILE(sizeof(ulonglong)<=sizeof(uint64_t)*2); // we've stopped at 128 bits, assert long long isn't bigger than this!
			q  = q>>3;
			// q now equals either a/10 or a/10-1

			// calculate r = remainder of a/10
			unsigned r = a - q*10;
			if(r>=10)
				{
				// remainder too big (due to rounding error in approximate divide by 10) so adjust values...
				r -= 10;
				++q;
				}
			a = q; // a now equals the original a/10

			*--dst = r+'0'; // store single digit
			--precision;
			}
		while(a);
		}

	// add leading zeros to reach the required precision...
	while(--precision>=0)
		*--dst = '0';

	return dst;
	}



/**
@brief Default handling for conversion specifiers which VFormat doesn't implement.

Floating point conversions (a, e, f, g and their upper case forms) aren't
supported; their argument is removed from the argument list and discarded so
that later arguments are still read correctly. All other specifiers are ignored
without consuming any argument.

@param dstEnd	Reference to the pointer marking the end of the converted text.
				This function doesn't modify it.
@param spec		The decoded conversion specification.

@return Pointer to the start of the converted text. This is always @a dstEnd,
		i.e. the text is empty.
*/
char* StringFormatter::DefaultUnkownFormat(char*& dstEnd, ConversionSpec& spec)
	{
	char specifier = spec.ConversionSpecifier;
	switch(specifier)
		{
		case 'a': case 'A':
		case 'e': case 'E':
		case 'f': case 'F':
		case 'g': case 'G':
			// floating point isn't supported, just skip over the argument...
			if(spec.Flags&LengthMod_L)
				va_arg(spec.Args,long double);
			else
				va_arg(spec.Args,double);
			break;

		default:
			// unknown specifier, ignore it
			break;
		}

	return dstEnd;
	}



#ifndef STRINGFORMATTER_UNKOWNFORMAT_DEFINED

/**
@brief Handle a conversion specifier which VFormat doesn't implement.

This definition simply forwards to DefaultUnkownFormat. It is only compiled
when STRINGFORMATTER_UNKOWNFORMAT_DEFINED isn't defined.

@param dstEnd	Reference to the pointer marking the end of the converted text.
@param format	The decoded conversion specification.

@return Pointer to the start of the converted text.
*/
char* StringFormatter::UnkownFormat(char*& dstEnd, ConversionSpec& format)
	{
	char* const textStart = DefaultUnkownFormat(dstEnd, format);
	return textStart;
	}

#endif


size_t StringFormatter::VFormat(const char* formatString, va_list args)
	{
	ConversionSpec convertionSpec(args);

	size_t outSize = 0;
	const char* src = formatString;
	for(;;)
		{
		char c;

		// get characters to copy...
		const char* srcBase = src;
		do c = *src++;
		while(c!=0 && c!='%');

		// ouput everything up to character which stopped us...
		size_t size = src-srcBase-1;
		outSize += size;
		Out(srcBase,size);

		if(c==0)
			{
			// end of formatString...
			return outSize;
			}

		if(*src=='%')
			{
			// '%%' found, output '%' and continue...
			++outSize;
			Out(src++,1);
			continue;
			}

		// get format data...
		src = convertionSpec.Decode(src);

		// initialise data for conversion...
		char temp[FormatTextBufferSize];
		ASSERT_COMPILE(FormatTextBufferSize>sizeof(longlong)/2*5+1);     // check big enough for decimal number with sign character
		ASSERT_COMPILE(FormatTextBufferSize>(sizeof(longlong)*8+2)/3+1); // check big enough for octal number with leading zero
		ASSERT_COMPILE(FormatTextBufferSize>(sizeof(longlong)*2)+2);     // check big enough for hex number with leading '0x'
		char* stringEnd = temp+sizeof(temp);
		register char* string = stringEnd;
		size_t zeroPadPosition = 0;
		unsigned flags = convertionSpec.Flags;

		// do conversion...
		c = convertionSpec.ConversionSpecifier;
		if(c=='c')
			{
			*--string = (char)convertionSpec.GetIntArg();
			}
		else if(c=='p')
			{
			ulonglong val = (ulonglong)(uintptr_t)convertionSpec.GetPointerArg();
			string = PushHex(val,string,sizeof(void*)*2,'x',false);
			}
		else if(c=='s')
			{
			string = (char*)convertionSpec.GetPointerArg();
			stringEnd = string;
			int precision = convertionSpec.Precision;
			while(*stringEnd++ && precision--) {}
			--stringEnd;
			}
		else if(c=='n')
			{
			void* ptr = convertionSpec.GetPointerArg();
			     if(flags&LengthMod_ll)		*(longlong*)ptr = outSize;
			else if(flags&LengthMod_l)		*(long*)ptr = outSize;
			else if(flags&LengthMod_h)		*(short*)ptr = outSize;
			else if(flags&LengthMod_hh)		*(signed char*)ptr = outSize;
			else							*(int*)ptr = outSize;
			}
		else if(c!='d' && c!='i' && c!='o' && c!='u' && c!='x' && c!='X')
			{
			string = UnkownFormat(stringEnd,convertionSpec);
			}
		else // an integer conversion...
			{
			// clip precission so we don't overflow our buffer...
			const int MaxPrecision = FormatTextBufferSize-2; // allow 2 extra for hex prefix '0x'
			int precision = convertionSpec.Precision;
			if(precision>MaxPrecision)
				precision = MaxPrecision;
			if(precision>=0)
				flags &= ~FlagZero;

			if(c=='d' || c=='i')
				{
				longlong val = convertionSpec.GetIntArg(flags|SignedInt);
				char sign = 0;
				if(val<0)
					{
					sign = '-';
					val = -val;
					}
				else if(flags&FlagPlus)
					sign = '+';
				else if(flags&FlagSpace)
					sign = ' ';
				string = PushDecimal(val,string,precision);
				if(sign)
					{
					zeroPadPosition = 1;
					*--string = sign;
					}
				}
			else
				{
				ulonglong val = convertionSpec.GetIntArg(flags);
				if(c=='u')
					{
					string = PushDecimal(val,string,precision);
					}
				else if(c=='o')
					{
					string = PushOctal(val,string,precision,flags&FlagHash);
					}
				else // c=='x' || c=='X'
					{
					if(flags&FlagHash)
						zeroPadPosition = 2; // position after any '0x' prefix
					string = PushHex(val,string,precision,c,flags&FlagHash);
					}
				}
			}

		// work out where we need to insert padding to get required field width...
		char* padPoint;
		char pad = ' ';
		if(flags&FlagMinus)
			{
			padPoint = stringEnd;
			}
		else if(flags&FlagZero)
			{
			pad = '0';
			padPoint = string+zeroPadPosition;
			if(padPoint<string || padPoint>stringEnd)
				padPoint = stringEnd;
			}
		else
			{
			padPoint = string;
			}

		// calculate width of converted argument....
		int width = stringEnd-string;

		// add produced string (up to zeroPadPosition)...
		size = padPoint-string;
		outSize += size;
		Out(string,size);
		string += size;

		// pad (if required)...
		int fieldWidth = convertionSpec.FieldWidth;
		if(width<fieldWidth)
			{
			size = fieldWidth-width;
			outSize += size;
			Out(pad,size);
			}

		// add rest of produced string...
		size = stringEnd-string;
		outSize += size;
		Out(string,size);
		string += size;
		}
	}


/**
@brief Produce formatted text from a printf style format string.

This collects the variable arguments and passes them to VFormat().

@param formatString	The null terminated format string.
@param ...			The arguments referred to by the format string.

@return The value returned by VFormat().
*/
size_t StringFormatter::Format(const char* formatString, ...)
	{
	va_list varArgs;
	va_start(varArgs,formatString);
	const size_t produced = VFormat(formatString,varArgs);
	va_end(varArgs);
	return produced;
	}


ASSERT_COMPILE(16/sizeof(void*)*sizeof(void*)==16); // code assumes sizeof(void*) is a factor of 16

/**
@brief Produce one line of a hex dump.

The line consists of an address, up to 16 bytes shown as two digit hexadecimal
values, then the same bytes shown as characters (with '.' substituted for any
value less than 32 or greater than 0x7e) and finally a newline. If fewer than 16
bytes are available the remaining columns are filled with spaces.

@param data				Pointer to the data to dump.
@param size				Number of bytes of data available.
@param addressOffset	Value added to @a data to produce the address shown at
						the start of the line.

@return The number of bytes which remain after this line, i.e. @a size minus 16,
		or zero if @a size is 16 or less.
*/
size_t StringFormatter::HexDumpLine(const void* data, size_t size, ptrdiff_t addressOffset)
	{
	const uint8_t* src = (const uint8_t*)data;

	// character form of the bytes. This is explicitly null terminated so that
	// printing it as a string never depends on reading beyond the array.
	char bytes[17];
	bytes[16] = 0;

	// Output address. The 'p' conversion produces sizeof(void*)*2 zero padded
	// hexadecimal digits and reads a pointer sized argument; using 'x' here
	// would read the pointer as an unsigned int. The address is calculated with
	// integer arithmetic because it needn't lie within the object being dumped.
	const void* address = (const void*)((uintptr_t)src+(uintptr_t)addressOffset);
	Format("%p",address);

	for(unsigned i=0; i<16; i++)
		{
		const char* format;
		unsigned b;
		if(i<size)
			{
			b = *src++;
			format = "%*.2x";
			}
		else
			{
			// no data for this column, fill it with spaces
			b = ' ';
			format = "%*c";
			}

		// use an extra space to separate each pointer sized group of bytes
		int width = (i&(sizeof(uintptr_t)-1))==0 ? 4 : 3;

		Format(format,width,b);

		if(b<32 || b>0x7e)
			b = '.';
		bytes[i] = b;
		}

	Format("  %.16s\n",bytes);

	return size>16 ? size-16 : 0;
	}



//
// StringBufferFormatter
//

/**
@brief Construct a formatter which stores its output in a memory buffer.

@param buffer	Pointer to the start of the buffer.
@param size		Size of the buffer, in bytes. If the buffer would extend beyond
				the highest address, it is clipped to end at that address.
*/
StringBufferFormatter::StringBufferFormatter(char* buffer, size_t size)
	{
	// Clip size so that the end of the buffer doesn't wrap past the end of
	// memory. This is done with integer arithmetic because testing for a
	// wrapped pointer (buffer+size<buffer) relies on undefined behaviour and
	// such tests may be removed by an optimising compiler.
	uintptr_t maxSize = ~(uintptr_t)0-(uintptr_t)buffer;
	if(size>maxSize)
		size = (size_t)maxSize;

	BufferStart = buffer;
	BufferPtr = buffer;
	BufferEnd = buffer+size;
	}


/**
@brief Append text to the buffer.

Text which doesn't fit in the space remaining in the buffer is discarded.

@param text		Pointer to the text to append.
@param textSize	Number of characters of text.
*/
void StringBufferFormatter::Out(const char* text, size_t textSize)
	{
	char* out = BufferPtr;

	// Clip the amount copied to the space remaining. The count is compared,
	// rather than pointers, because out+textSize may wrap and testing for a
	// wrapped pointer relies on undefined behaviour. The space is calculated
	// with integers as it may be too large to represent as a ptrdiff_t.
	size_t space = (size_t)((uintptr_t)BufferEnd-(uintptr_t)out);
	if(textSize>space)
		textSize = space;

	char* end = out+textSize;
	while(out<end)
		*out++ = *text++;
	BufferPtr = out;
	}


/**
@brief Append a repeated character to the buffer.

Characters which don't fit in the space remaining in the buffer are discarded.

@param character	The character to append.
@param repeatCount	The number of times to append it.
*/
void StringBufferFormatter::Out(char character, size_t repeatCount)
	{
	char* out = BufferPtr;

	// Clip the amount stored to the space remaining. The count is compared,
	// rather than pointers, because out+repeatCount may wrap and testing for a
	// wrapped pointer relies on undefined behaviour. The space is calculated
	// with integers as it may be too large to represent as a ptrdiff_t.
	size_t space = (size_t)((uintptr_t)BufferEnd-(uintptr_t)out);
	if(repeatCount>space)
		repeatCount = space;

	char* end = out+repeatCount;
	while(out<end)
		*out++ = character;
	BufferPtr = out;
	}


/**
@brief Finish the text in the buffer.

A terminating null character is stored after the text, provided there is space
for it in the buffer. The formatter is then reset so that subsequent output
starts again from the beginning of the buffer.

@return Pointer to the end of the text, i.e. the position where the terminating
		null character belongs (whether or not it was actually stored).
*/
char* StringBufferFormatter::End()
	{
	char* const textEnd = BufferPtr;

	// only terminate if this doesn't write past the end of the buffer...
	if(textEnd<BufferEnd)
		*textEnd = 0;

	// rewind, ready for reuse...
	BufferPtr = BufferStart;

	return textEnd;
	}



//
// C Standard library functions
//

/**
@brief Implementation of the C Standard Library function vsprintf.

@param str		Buffer to receive the formatted text. No limit is placed on the
				amount of text stored.
@param format	The null terminated format string.
@param ap		The arguments referred to by the format string.

@return The number of characters produced, as returned by
		StringFormatter::VFormat.
*/
extern "C" int vsprintf(char* str, const char* format, va_list ap)
	{
	// use the largest possible size, the formatter clips this to the end of memory
	const size_t unlimited = ~(size_t)0;
	StringBufferFormatter formatter(str,unlimited);

	const int written = (int)formatter.VFormat(format, ap);
	formatter.End(); // adds terminating null
	return written;
	}


/**
@brief Implementation of the C Standard Library function sprintf.

This collects the variable arguments and passes them to vsprintf().

@param str		Buffer to receive the formatted text.
@param format	The null terminated format string.
@param ...		The arguments referred to by the format string.

@return The value returned by vsprintf().
*/
extern "C" int sprintf(char* str, const char* format, ...)
	{
	va_list varArgs;
	va_start(varArgs,format);
	const int written = vsprintf(str,format,varArgs);
	va_end(varArgs);
	return written;
	}


/**
@brief Implementation of the C Standard Library function vsnprintf.

@param str		Buffer to receive the formatted text.
@param size		Size of the buffer. Text which doesn't fit is discarded and,
				if this is non-zero, the stored text is always null terminated.
@param format	The null terminated format string.
@param ap		The arguments referred to by the format string.

@return The number of characters produced by the format string, as returned by
		StringFormatter::VFormat. This counts discarded characters too.
*/
extern "C" int vsnprintf(char* str, size_t size, const char* format, va_list ap)
	{
	StringBufferFormatter formatter(str,size);
	const size_t produced = formatter.VFormat(format, ap);
	formatter.End(); // adds terminating null, but only if there is space after the text

	// if the text filled (or overflowed) the buffer then the last character
	// has to be sacrificed to make room for the terminating null...
	if(size!=0 && produced>=size)
		str[size-1] = 0;

	return (int)produced;
	}


/**
@brief Implementation of the C Standard Library function snprintf.

This collects the variable arguments and passes them to vsnprintf().

@param str		Buffer to receive the formatted text.
@param size		Size of the buffer.
@param format	The null terminated format string.
@param ...		The arguments referred to by the format string.

@return The value returned by vsnprintf().
*/
extern "C" int snprintf(char* str, size_t size, const char* format, ...)
	{
	va_list varArgs;
	va_start(varArgs,format);
	const int produced = vsnprintf(str,size,format,varArgs);
	va_end(varArgs);
	return produced;
	}


