#include "lexer.h"
#include "common/stringmap.h"
#include "scirpt/lexer.h"
#include "scirpt/position.h"
#include "scirpt/token.h"
#include <stdlib.h>
#include <string.h>

#define TT(type) ScirptTokenType##type

// Shorthand: advance the lexer by one character.
static inline void step(ScirptLexer* lexer)
{
	scirpt_lexer_step(lexer);
}
// Shorthand: build a token of `type` spanning from `start` to the current index.
static inline ScirptToken
token(const ScirptLexer* lexer, ScirptTokenType type, ScirptPosition start)
{
	ScirptToken made = scirpt_lexer_token(lexer, type, start);
	return made;
}
// Shorthand: snapshot the lexer's current index/line/col.
static inline ScirptPosition pos(const ScirptLexer* lexer)
{
	ScirptPosition here = scirpt_lexer_pos(lexer);
	return here;
}
// Shorthand: true when not at end of input and the current char equals `value`.
static inline bool current_is(const ScirptLexer* lexer, char value)
{
	const bool matches = scirpt_lexer_current_is(lexer, value);
	return matches;
}
// Shorthand: true once the lexer has consumed all of its input text.
static inline bool done(const ScirptLexer* lexer)
{
	const bool finished = scirpt_lexer_done(lexer);
	return finished;
}
// Shorthand: the character at the current index (caller must check `done`).
static inline char current(const ScirptLexer* lexer)
{
	const char ch = scirpt_lexer_current(lexer);
	return ch;
}

// Allocates a lexer on the heap and initializes it over `text`.
// The text is referenced, not copied, so it must outlive the lexer.
// Returns NULL if the allocation fails. Release with scirpt_lexer_delete.
ScirptLexer* scirpt_lexer_new(const char* text, size_t text_length)
{
	ScirptLexer* lexer = malloc(sizeof *lexer);
	// Without this check a failed malloc would be dereferenced inside
	// scirpt_lexer_create.
	if (lexer == NULL)
		return NULL;
	scirpt_lexer_create(lexer, text, text_length);
	return lexer;
}

// Counterpart of scirpt_lexer_new: releases the resources acquired by
// scirpt_lexer_create (the keyword map) and then frees the lexer itself.
// Passing NULL is a no-op.
void scirpt_lexer_delete(ScirptLexer* lexer)
{
	if (lexer == NULL)
		return;
	// Previously only the struct was freed, leaking the keyword map that
	// scirpt_lexer_new allocates through scirpt_lexer_create.
	scirpt_lexer_destroy(lexer);
	free(lexer);
}

// Registers the NUL-terminated `key` in `keywords`, mapping it to `value`.
// scirpt_lexer_id_token looks identifiers up in this map.
static inline void
add_keyword(StringMap* keywords, const char* key, ScirptTokenType value)
{
	const size_t key_length = strlen(key);
	stringmap_set(keywords, key, key_length, value);
}

// Initializes `lexer` in place to scan `text` (not copied) from line 1, col 1.
// Allocates a keyword map that scirpt_lexer_destroy releases.
void scirpt_lexer_create(
	ScirptLexer* lexer, const char* text, size_t text_length
)
{
	// Reserved words and the token types they lex as, registered in order.
	static const struct {
		const char* word;
		ScirptTokenType type;
	} reserved[] = {
		{ "null", TT(Null) },	{ "false", TT(False) }, { "true", TT(True) },
		{ "not", TT(Not) },		{ "and", TT(And) },		{ "or", TT(Or) },
		{ "let", TT(Let) },		{ "if", TT(If) },		{ "else", TT(Else) },
		{ "while", TT(While) }, { "for", TT(For) },		{ "in", TT(In) },
		{ "break", TT(Break) }, { "fn", TT(Fn) },		{ "return", TT(Return) },
	};
	const size_t reserved_count = sizeof reserved / sizeof reserved[0];

	StringMap* keyword_map = stringmap_new();
	for (size_t i = 0; i < reserved_count; i++)
		add_keyword(keyword_map, reserved[i].word, reserved[i].type);

	// Compound-literal assignment zeroes any field not listed here.
	*lexer = (ScirptLexer) {
		.keywords = keyword_map,
		.text = text,
		.text_length = text_length,
		.index = 0,
		.line = 1,
		.col = 1,
	};
}

// Releases the keyword map created by scirpt_lexer_create.
// The ScirptLexer struct itself is not freed.
void scirpt_lexer_destroy(ScirptLexer* lexer)
{
	StringMap* const keyword_map = lexer->keywords;
	stringmap_delete(keyword_map);
}

// True for the four characters the lexer skips between tokens:
// space, tab, carriage return and line feed.
static inline bool is_whitespace(char value)
{
	switch (value) {
		case ' ':
		case '\t':
		case '\r':
		case '\n':
			return true;
		default:
			return false;
	}
}

// True for characters that may start an identifier: ASCII letters and '_'.
static inline bool is_id_char_excluding_numbers(char value)
{
	if (value == '_')
		return true;
	const bool lower = 'a' <= value && value <= 'z';
	const bool upper = 'A' <= value && value <= 'Z';
	return lower || upper;
}

// True for ASCII decimal digits '0'..'9'.
static inline bool is_int_char(char value)
{
	return !(value < '0' || value > '9');
}

// True for characters allowed after the first character of an identifier:
// digits as well as letters and '_'.
static inline bool is_id_char(char value)
{
	return is_int_char(value) || is_id_char_excluding_numbers(value);
}

// Public entry point: returns the next token by entering the level-1 dispatcher.
ScirptToken scirpt_lexer_next(ScirptLexer* lexer)
{
	ScirptToken next_token = scirpt_lexer_level_1(lexer);
	return next_token;
}

// First dispatch level. Emits Eof at end of input, skips whitespace, and
// lexes identifiers/keywords. Everything else goes to level 2.
ScirptToken scirpt_lexer_level_1(ScirptLexer* lexer)
{
	if (done(lexer))
		return token(lexer, TT(Eof), pos(lexer));

	const char head = current(lexer);
	if (is_whitespace(head))
		return scirpt_lexer_skip_whitespace(lexer);
	if (is_id_char_excluding_numbers(head))
		return scirpt_lexer_id_token(lexer);
	return scirpt_lexer_level_2(lexer);
}

// Consumes exactly one character and returns it as a token of `type`.
static inline ScirptToken single_token(ScirptLexer* lexer, ScirptTokenType type)
{
	const ScirptPosition begin = scirpt_lexer_pos(lexer);
	scirpt_lexer_step(lexer);
	return scirpt_lexer_token(lexer, type, begin);
}

// Consumes the current character. If the next character is `second_char`,
// it is consumed too and a `second_type` token is returned; otherwise the
// result is a one-character `first_type` token.
static inline ScirptToken single_or_double_token(
	ScirptLexer* lexer,
	ScirptTokenType first_type,
	char second_char,
	ScirptTokenType second_type
)
{
	const ScirptPosition begin = pos(lexer);
	step(lexer);
	const bool doubled = current_is(lexer, second_char);
	if (doubled)
		step(lexer);
	return token(lexer, doubled ? second_type : first_type, begin);
}

// Second dispatch level: punctuation, operators, string literals and a lone
// '0'. Anything not matched here is passed on to level 3.
// Only reached from level 1, after it has verified the lexer is not done.
ScirptToken scirpt_lexer_level_2(ScirptLexer* lexer)
{
	switch (current(lexer)) {
		case '0':
			// '0' is always a complete one-character Int; digits 1-9 reach
			// scirpt_lexer_int_token via level 3, so "012" lexes as "0" then
			// "12". NOTE(review): confirm leading zeros are meant to split.
			return single_token(lexer, TT(Int));
		case '"':
			return scirpt_lexer_string_token(lexer);
		case '(':
			return single_token(lexer, TT(LParen));
		case ')':
			return single_token(lexer, TT(RParen));
		case '{':
			return single_token(lexer, TT(LBrace));
		case '}':
			return single_token(lexer, TT(RBrace));
		case '[':
			return single_token(lexer, TT(LBracket));
		case ']':
			return single_token(lexer, TT(RBracket));
		case '.':
			return single_token(lexer, TT(Dot));
		case ',':
			return single_token(lexer, TT(Comma));
		case ':':
			return single_token(lexer, TT(Colon));
		case ';':
			return single_token(lexer, TT(Semicolon));
		case '+':
			return single_or_double_token(lexer, TT(Plus), '=', TT(PlusEqual));
		case '-':
			return single_or_double_token(
				lexer, TT(Minus), '=', TT(MinusEqual)
			);
		case '*':
			return single_or_double_token(
				lexer, TT(Asterisk), '=', TT(AsteriskEqual)
			);
		case '/':
			// '/' may begin a line comment, a block comment, "/=" or a plain
			// Slash, so it needs its own sub-lexer.
			return scirpt_lexer_slash_token(lexer);
		case '%':
			return single_or_double_token(
				lexer, TT(Percent), '=', TT(PercentEqual)
			);
		case '=':
			return single_or_double_token(
				lexer, TT(Equal), '=', TT(EqualEqual)
			);
		case '!':
			return single_or_double_token(
				lexer, TT(Exclamation), '=', TT(ExclamationEqual)
			);
		case '<':
			return single_or_double_token(lexer, TT(Lt), '=', TT(LtEqual));
		case '>':
			return single_or_double_token(lexer, TT(Gt), '=', TT(GtEqual));
		default:
			return scirpt_lexer_level_3(lexer);
	}
}

// Last dispatch level: integer literals, or else a one-character
// InvalidChar token.
ScirptToken scirpt_lexer_level_3(ScirptLexer* lexer)
{
	return is_int_char(current(lexer))
		? scirpt_lexer_int_token(lexer)
		: single_token(lexer, TT(InvalidChar));
}

// Consumes a run of whitespace (the current character is known to be
// whitespace) and returns the token that follows it.
ScirptToken scirpt_lexer_skip_whitespace(ScirptLexer* lexer)
{
	do {
		step(lexer);
	} while (!done(lexer) && is_whitespace(current(lexer)));
	return scirpt_lexer_next(lexer);
}

// Lexes a maximal run of identifier characters. If the spelling is in the
// keyword map, the stored keyword type is returned; otherwise the token is Id.
ScirptToken scirpt_lexer_id_token(ScirptLexer* lexer)
{
	const ScirptPosition begin = pos(lexer);
	do {
		step(lexer);
	} while (!done(lexer) && is_id_char(current(lexer)));

	const char* word = lexer->text + begin.index;
	const size_t word_length = lexer->index - begin.index;
	const size_t* keyword_type
		= stringmap_get(lexer->keywords, word, word_length);

	ScirptTokenType type
		= keyword_type != NULL ? (ScirptTokenType)*keyword_type : TT(Id);
	return token(lexer, type, begin);
}

// Lexes a maximal run of decimal digits as an Int token.
ScirptToken scirpt_lexer_int_token(ScirptLexer* lexer)
{
	const ScirptPosition begin = pos(lexer);
	do {
		step(lexer);
	} while (!done(lexer) && is_int_char(current(lexer)));
	return token(lexer, TT(Int), begin);
}

// Lexes a double-quoted string literal, including both quotes. A backslash
// causes the following character to be skipped, so \" does not end the
// string. Reaching end of input first yields MalformedString.
ScirptToken scirpt_lexer_string_token(ScirptLexer* lexer)
{
	const ScirptPosition begin = pos(lexer);
	step(lexer); // opening quote
	for (;;) {
		if (done(lexer) || current(lexer) == '"')
			break;
		const bool escaped = current(lexer) == '\\';
		step(lexer);
		if (escaped && !done(lexer))
			step(lexer);
	}
	// The loop only exits at end of input or on a closing quote.
	if (done(lexer))
		return token(lexer, TT(MalformedString), begin);
	step(lexer); // closing quote
	return token(lexer, TT(String), begin);
}

// Lexes everything that begins with '/':
//   "//"  line comment: skipped up to (not including) the next '\n'
//   "/*"  block comment: skipped; block comments may nest
//   "/="  SlashEqual
//   "/"   Slash
// Skipped comments produce no token; lexing resumes via scirpt_lexer_next.
// A block comment still open at end of input yields a MalformedComment
// token spanning from the opening '/' to the end of the text.
ScirptToken scirpt_lexer_slash_token(ScirptLexer* lexer)
{
	ScirptPosition start = pos(lexer);
	step(lexer); // consume the leading '/'
	// Compare against characters, not token-type enum constants.
	if (current_is(lexer, '/')) {
		step(lexer);
		while (!done(lexer) && current(lexer) != '\n')
			step(lexer);
		return scirpt_lexer_next(lexer);
	} else if (current_is(lexer, '*')) {
		step(lexer);
		// The opening "/*" has already been consumed, so we start at depth 1.
		// Each nested "/*" increments the depth and each "*/" decrements it.
		int depth = 1;
		// Previous unpaired character. It is reset after every matched pair
		// so one character is never shared by a closer and an opener
		// (e.g. the '/' of "*/" must not also start "/*" in "*/*").
		char last_char = '\0';
		while (!done(lexer) && depth > 0) {
			char c = current(lexer);
			if (last_char == '/' && c == '*') {
				depth++;
				last_char = '\0';
			} else if (last_char == '*' && c == '/') {
				depth--;
				last_char = '\0';
			} else {
				last_char = c;
			}
			step(lexer);
		}
		if (depth != 0)
			return token(lexer, TT(MalformedComment), start);
		return scirpt_lexer_next(lexer);
	} else if (current_is(lexer, '=')) {
		step(lexer);
		return token(lexer, TT(SlashEqual), start);
	} else {
		return token(lexer, TT(Slash), start);
	}
}

// Advances the lexer by one character, keeping line/col in sync with the
// new current character. Line and column are updated from the character
// being LEFT: stepping past '\n' moves to column 1 of the next line, and any
// other character advances the column. (Updating from the character stepped
// onto would put the '\n' itself on the next line and start every later line
// at column 2.) Stepping when already done does nothing, so index never
// exceeds text_length.
void scirpt_lexer_step(ScirptLexer* lexer)
{
	if (done(lexer))
		return;
	if (current(lexer) == '\n') {
		lexer->line++;
		lexer->col = 1;
	} else {
		lexer->col++;
	}
	lexer->index++;
}
ScirptPosition scirpt_lexer_pos(const ScirptLexer* lexer)
{
	return (ScirptPosition) {
		.index = lexer->index,
		.line = lexer->line,
		.col = lexer->col,
	};
}

// Builds a token of `type` that starts at `start` and whose length covers
// everything consumed between `start` and the lexer's current index.
ScirptToken scirpt_lexer_token(
	const ScirptLexer* lexer, ScirptTokenType type, ScirptPosition start
)
{
	ScirptToken result = {
		.length = lexer->index - start.index,
		.pos = start,
		.type = type,
	};
	return result;
}

// True when input remains and the current character equals `value`.
// It is safe to call at end of input.
bool scirpt_lexer_current_is(const ScirptLexer* lexer, char value)
{
	if (done(lexer))
		return false;
	return current(lexer) == value;
}

// True once the index has reached (or passed) the end of the input text.
bool scirpt_lexer_done(const ScirptLexer* lexer)
{
	return !(lexer->index < lexer->text_length);
}

// Returns the character at the current index. There is no bounds check:
// callers must ensure the lexer is not done.
char scirpt_lexer_current(const ScirptLexer* lexer)
{
	const char* cursor = lexer->text + lexer->index;
	return *cursor;
}