#include "lexer.h"
#include <ctype.h>
#include <stdbool.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

// Lexer state: a read-only view of the source text plus the current scan
// position. `text` is not owned by the lexer; its size is given by `length`.
struct Lexer {
    const char* text;      // source buffer being scanned (borrowed)
    size_t index, length;  // current byte offset and total buffer size
    int line, column;      // 1-based position of the character at `index`
};

// File-local helper declarations. Each `lexer_make_*` / `lexer_skip_*`
// function assumes the lexer is positioned at the first character of the
// construct it handles and leaves the lexer positioned just past it.
Token lexer_skip_whitespace(Lexer* lexer);
Token lexer_make_int_or_float(Lexer* lexer);
Token lexer_make_id(Lexer* lexer);
bool lexer_span_matches(const Lexer* lexer, Position begin, const char* value);
Token lexer_make_static_token(Lexer* lexer);
Token lexer_make_int_hex_or_binary(Lexer* lexer);
Token lexer_make_char(Lexer* lexer);
Token lexer_make_string(Lexer* lexer);
void lexer_skip_literal_char(Lexer* lexer);
Token lexer_make_single_char_token(Lexer* lexer, TokenType type);
Token lexer_make_slash_token(Lexer* lexer);
Token lexer_skip_singleline_comment(Lexer* lexer);
Token lexer_make_single_or_double_char_token(
    Lexer* lexer, TokenType single_type, char second_char, TokenType double_type);
Token lexer_skip_multiline_comment(Lexer* lexer);
Token lexer_make_invalid_char(Lexer* lexer);
Position lexer_position(const Lexer* lexer);
Token lexer_token(const Lexer* lexer, TokenType type, Position begin);
bool lexer_done(const Lexer* lexer);
char lexer_current(const Lexer* lexer);
void lexer_step(Lexer* lexer);

// Initialize `lexer` to scan `text` (of `text_length` bytes) from the start.
// Line/column numbering is 1-based.
void lexer_create(Lexer* lexer, const char* text, size_t text_length)
{
    lexer->text = text;
    lexer->index = 0;
    lexer->length = text_length;
    lexer->line = 1;
    lexer->column = 1;
}

// Produce the next token from the input, skipping whitespace and comments.
// Returns a TokenTypeEof token once the input is exhausted.
Token lexer_next(Lexer* lexer)
{
    // Check for end of input BEFORE reading the current character: the
    // previous version read text[index] first, an out-of-bounds access
    // whenever index == length.
    if (lexer_done(lexer))
        return lexer_token(lexer, TokenTypeEof, lexer_position(lexer));
    char c = lexer_current(lexer);
    // Cast to unsigned char: passing a negative char to <ctype.h> is UB.
    if (isspace((unsigned char)c))
        return lexer_skip_whitespace(lexer);
    // '0' is deliberately excluded here; it routes through
    // lexer_make_static_token to lexer_make_int_hex_or_binary.
    else if (c >= '1' && c <= '9')
        return lexer_make_int_or_float(lexer);
    else if (isalpha((unsigned char)c) || c == '_')
        return lexer_make_id(lexer);
    else
        return lexer_make_static_token(lexer);
}

// Consume a run of whitespace (the current character is known to be
// whitespace) and return the token that follows it.
Token lexer_skip_whitespace(Lexer* lexer)
{
    do
        lexer_step(lexer);
    while (!lexer_done(lexer) && isspace(lexer_current(lexer)));
    return lexer_next(lexer);
}

// Lex a decimal integer or float starting at the current (nonzero) digit.
// Returns TokenTypeFloat when a '.' follows the integer part, else
// TokenTypeInt.
Token lexer_make_int_or_float(Lexer* lexer)
{
    Position begin = lexer_position(lexer);
    lexer_step(lexer);
    while (!lexer_done(lexer) && isdigit((unsigned char)lexer_current(lexer)))
        lexer_step(lexer);
    if (!lexer_done(lexer) && lexer_current(lexer) == '.') {
        // BUG FIX: the '.' itself was never consumed, so the fraction loop
        // below exited immediately and the fractional digits were left for
        // the next token. Step past the '.' first.
        lexer_step(lexer);
        while (!lexer_done(lexer) && isdigit((unsigned char)lexer_current(lexer)))
            lexer_step(lexer);
        return lexer_token(lexer, TokenTypeFloat, begin);
    } else {
        return lexer_token(lexer, TokenTypeInt, begin);
    }
}

// Lex an identifier ([A-Za-z_][A-Za-z0-9_]*) and classify it as a keyword
// token when it matches one, otherwise TokenTypeId.
Token lexer_make_id(Lexer* lexer)
{
    Position begin = lexer_position(lexer);
    lexer_step(lexer);
    while (!lexer_done(lexer)) {
        char c = lexer_current(lexer);
        if (!(isalpha(c) || isdigit(c) || c == '_'))
            break;
        lexer_step(lexer);
    }
    // Keyword lookup table; falls through to a plain identifier.
    static const struct {
        const char* word;
        TokenType type;
    } keywords[] = {
        { "if", TokenTypeIf },
        { "else", TokenTypeElse },
        { "while", TokenTypeWhile },
        { "break", TokenTypeBreak },
    };
    for (size_t i = 0; i < sizeof keywords / sizeof keywords[0]; ++i) {
        if (lexer_span_matches(lexer, begin, keywords[i].word))
            return lexer_token(lexer, keywords[i].type, begin);
    }
    return lexer_token(lexer, TokenTypeId, begin);
}

// True when the text span from `begin` to the current position equals
// `value` exactly (same length, same bytes).
bool lexer_span_matches(const Lexer* lexer, Position begin, const char* value)
{
    size_t span_length = lexer->index - begin.index;
    return span_length == strlen(value)
        && strncmp(lexer->text + begin.index, value, span_length) == 0;
}

// Dispatch on the current character to lex a literal, punctuation, or
// operator token. Reached from lexer_next for any character that is not
// whitespace, a nonzero digit, or an identifier start.
Token lexer_make_static_token(Lexer* lexer)
{
    switch (lexer_current(lexer)) {
        // '0' prefixes hex (0x...), binary (0b...), or a plain zero.
        case '0':
            return lexer_make_int_hex_or_binary(lexer);
        case '\'':
            return lexer_make_char(lexer);
        case '"':
            return lexer_make_string(lexer);
        case '(':
            return lexer_make_single_char_token(lexer, TokenTypeLParen);
        case ')':
            return lexer_make_single_char_token(lexer, TokenTypeRParen);
        case '{':
            return lexer_make_single_char_token(lexer, TokenTypeLBrace);
        case '}':
            return lexer_make_single_char_token(lexer, TokenTypeRBrace);
        case '[':
            return lexer_make_single_char_token(lexer, TokenTypeLBracket);
        case ']':
            return lexer_make_single_char_token(lexer, TokenTypeRBracket);
        case '.':
            return lexer_make_single_char_token(lexer, TokenTypeDot);
        case ',':
            return lexer_make_single_char_token(lexer, TokenTypeComma);
        case ':':
            return lexer_make_single_char_token(lexer, TokenTypeColon);
        case ';':
            return lexer_make_single_char_token(lexer, TokenTypeSemicolon);
        // Operators that may be followed by '=' to form a compound token.
        case '+':
            return lexer_make_single_or_double_char_token(
                lexer, TokenTypePlus, '=', TokenTypePlusEqual);
        case '-':
            return lexer_make_single_or_double_char_token(
                lexer, TokenTypeMinus, '=', TokenTypeMinusEqual);
        case '*':
            return lexer_make_single_or_double_char_token(
                lexer, TokenTypeAsterisk, '=', TokenTypeAsteriskEqual);
        // '/' needs its own handler: it may start a comment.
        case '/':
            return lexer_make_slash_token(lexer);
        case '%':
            return lexer_make_single_or_double_char_token(
                lexer, TokenTypePercent, '=', TokenTypePercentEqual);
        case '=':
            return lexer_make_single_or_double_char_token(
                lexer, TokenTypeEqual, '=', TokenTypeDoubleEqual);
        case '!':
            return lexer_make_single_or_double_char_token(
                lexer, TokenTypeExclamation, '=', TokenTypeExclamationEqual);
        case '<':
            return lexer_make_single_or_double_char_token(
                lexer, TokenTypeLt, '=', TokenTypeLtEqual);
        case '>':
            return lexer_make_single_or_double_char_token(
                lexer, TokenTypeGt, '=', TokenTypeGtEqual);
        default:
            return lexer_make_invalid_char(lexer);
    }
}

// Lex a number starting with '0': hex (0x...), binary (0b...), or a plain
// zero (TokenTypeInt).
Token lexer_make_int_hex_or_binary(Lexer* lexer)
{
    Position begin = lexer_position(lexer);
    lexer_step(lexer); // consume the leading '0'
    if (!lexer_done(lexer) && (lexer_current(lexer) == 'x' || lexer_current(lexer) == 'X')) {
        lexer_step(lexer); // consume the 'x'/'X' marker
        // BUG FIX: the digit test used `||` between the range comparisons
        // (`c >= 'a' || c <= 'f'` is always true), so the loop consumed
        // every remaining character of the input. isxdigit() is exactly
        // the intended predicate: 0-9, a-f, A-F.
        while (!lexer_done(lexer) && isxdigit((unsigned char)lexer_current(lexer)))
            lexer_step(lexer);
        return lexer_token(lexer, TokenTypeHex, begin);
    } else if (!lexer_done(lexer) && (lexer_current(lexer) == 'b' || lexer_current(lexer) == 'B')) {
        lexer_step(lexer); // consume the 'b'/'B' marker
        while (!lexer_done(lexer) && (lexer_current(lexer) == '0' || lexer_current(lexer) == '1'))
            lexer_step(lexer);
        return lexer_token(lexer, TokenTypeBinary, begin);
    } else {
        // NOTE(review): a '0' followed by decimal digits ("0123") lexes as
        // two tokens here — preserved as-is; confirm whether that is intended.
        return lexer_token(lexer, TokenTypeInt, begin);
    }
}

// Lex a character literal '<char>'. Returns TokenTypeMalformedChar when the
// literal is unterminated.
Token lexer_make_char(Lexer* lexer)
{
    Position begin = lexer_position(lexer);
    lexer_step(lexer); // consume the opening quote
    if (lexer_done(lexer))
        return lexer_token(lexer, TokenTypeMalformedChar, begin);
    lexer_skip_literal_char(lexer);
    // BUG FIX: the check was `done && current != '\''`, which (a) almost
    // never triggered and (b) read past the end of the buffer when done.
    // The literal is malformed when input ended OR the closing quote is
    // missing.
    if (lexer_done(lexer) || lexer_current(lexer) != '\'')
        return lexer_token(lexer, TokenTypeMalformedChar, begin);
    lexer_step(lexer); // consume the closing quote
    return lexer_token(lexer, TokenTypeChar, begin);
}

// Lex a string literal "...". Returns TokenTypeMalformedString when the
// literal is unterminated.
Token lexer_make_string(Lexer* lexer)
{
    Position begin = lexer_position(lexer);
    lexer_step(lexer); // consume the opening quote
    if (lexer_done(lexer))
        return lexer_token(lexer, TokenTypeMalformedString, begin);
    while (!lexer_done(lexer) && lexer_current(lexer) != '\"')
        lexer_skip_literal_char(lexer);
    // BUG FIX: the old check `done && current != '"'` read out of bounds
    // when done; the loop above only exits on done or on a closing quote,
    // so reaching the end of input here means the string is unterminated.
    if (lexer_done(lexer))
        return lexer_token(lexer, TokenTypeMalformedString, begin);
    lexer_step(lexer); // consume the closing quote
    // BUG FIX: previously returned TokenTypeChar for a string literal.
    return lexer_token(lexer, TokenTypeString, begin);
}

// Consume one logical character inside a char/string literal: either a
// plain character or a backslash escape (including multi-digit numeric and
// hex escapes). Leaves the lexer on the character after the literal char.
void lexer_skip_literal_char(Lexer* lexer)
{
    // Not an escape: a single character.
    if (lexer_current(lexer) != '\\') {
        lexer_step(lexer);
        return;
    }
    lexer_step(lexer); // consume the backslash
    if (lexer_done(lexer))
        return; // trailing backslash at end of input; caller reports malformed
    char previous = lexer_current(lexer);
    lexer_step(lexer); // consume the escape selector character
    // NOTE(review): numeric escapes start at '1' here and allow digits 8/9
    // below; C-style octal escapes are '0'-'7' with a leading '0'-'7'.
    // Confirm whether this language's escapes intentionally differ.
    if (previous >= '1' && previous <= '9') {
        while (!lexer_done(lexer) && isdigit(lexer_current(lexer)))
            lexer_step(lexer);
    } else if (previous == 'x' || previous == 'X') {
        // Hex escape: consume all following hex digits.
        while (!lexer_done(lexer)
            && (isdigit(lexer_current(lexer))
                || (lexer_current(lexer) >= 'a' && lexer_current(lexer) <= 'f')
                || (lexer_current(lexer) >= 'A' && lexer_current(lexer) <= 'F')))
            lexer_step(lexer);
    }
}

// Consume exactly one character and emit a token of the given type for it.
Token lexer_make_single_char_token(Lexer* lexer, TokenType type)
{
    Position start = lexer_position(lexer);
    lexer_step(lexer);
    return lexer_token(lexer, type, start);
}

// Lex an operator that is either one character (`single_type`) or, when
// immediately followed by `second_char`, two characters (`double_type`) —
// e.g. '+' vs "+=".
Token lexer_make_single_or_double_char_token(
    Lexer* lexer, TokenType single_type, char second_char, TokenType double_type)
{
    Position begin = lexer_position(lexer);
    lexer_step(lexer);
    if (!lexer_done(lexer) && lexer_current(lexer) == second_char) {
        lexer_step(lexer);
        // BUG FIX: the two-character match previously returned single_type
        // (and the one-character case returned double_type) — swapped.
        return lexer_token(lexer, double_type, begin);
    } else {
        return lexer_token(lexer, single_type, begin);
    }
}

// Lex a token starting with '/': a comment ("//" or "/*"), "/=", or a
// plain '/'.
Token lexer_make_slash_token(Lexer* lexer)
{
    Position begin = lexer_position(lexer);
    lexer_step(lexer);
    // BUG FIX: guard against end of input before peeking at the next
    // character; a trailing '/' previously read past the buffer.
    if (lexer_done(lexer))
        return lexer_token(lexer, TokenTypeSlash, begin);
    switch (lexer_current(lexer)) {
        case '/':
            return lexer_skip_singleline_comment(lexer);
        case '*':
            return lexer_skip_multiline_comment(lexer);
        case '=':
            lexer_step(lexer);
            return lexer_token(lexer, TokenTypeSlashEqual, begin);
        default:
            return lexer_token(lexer, TokenTypeSlash, begin);
    }
}

// Consume a "//" comment through its terminating newline (or end of
// input) and return the token that follows it.
Token lexer_skip_singleline_comment(Lexer* lexer)
{
    lexer_step(lexer); // consume the second '/'
    while (!lexer_done(lexer)) {
        char c = lexer_current(lexer);
        lexer_step(lexer);
        if (c == '\n')
            break;
    }
    return lexer_next(lexer);
}

// Consume a nestable "/* ... */" comment (entered with the lexer on the
// '*' of the opener) and return the token that follows it. Emits
// TokenTypeMalformedMultilineComment when the input ends inside the
// comment.
Token lexer_skip_multiline_comment(Lexer* lexer)
{
    lexer_step(lexer); // consume the '*' of the opening "/*"
    int depth = 1;
    // BUG FIXES vs the original:
    //  - the loop never tested `depth`, so it always consumed the rest of
    //    the input instead of stopping at the matching "*/";
    //  - the close check was `lexer_done && current == '/'` (should be
    //    !lexer_done), so depth was never decremented;
    //  - the unconditional trailing step double-stepped past the second
    //    character of "/*" and "*/" sequences.
    while (!lexer_done(lexer) && depth > 0) {
        char c = lexer_current(lexer);
        lexer_step(lexer);
        if (c == '/' && !lexer_done(lexer) && lexer_current(lexer) == '*') {
            depth += 1;
            lexer_step(lexer);
        } else if (c == '*' && !lexer_done(lexer) && lexer_current(lexer) == '/') {
            depth -= 1;
            lexer_step(lexer);
        }
    }
    return depth != 0
        ? lexer_token(lexer, TokenTypeMalformedMultilineComment, lexer_position(lexer))
        : lexer_next(lexer);
}

// Consume one unrecognized character and emit an InvalidChar token for it.
Token lexer_make_invalid_char(Lexer* lexer)
{
    Position start = lexer_position(lexer);
    lexer_step(lexer);
    return lexer_token(lexer, TokenTypeInvalidChar, start);
}

Position lexer_position(const Lexer* lexer)
{
    return (Position) {
        .index = lexer->index,
        .line = lexer->line,
        .column = lexer->column,
    };
}

// Build a token of `type` spanning from `begin` to the current position.
Token lexer_token(const Lexer* lexer, TokenType type, Position begin)
{
    Token token;
    token.type = type;
    token.position = begin;
    token.length = lexer->index - begin.index;
    return token;
}

// True once every character of the input has been consumed.
bool lexer_done(const Lexer* lexer)
{
    return lexer->index >= lexer->length;
}

// The character at the current position; callers must check lexer_done first.
char lexer_current(const Lexer* lexer)
{
    return lexer->text[lexer->index];
}

// Advance one character, maintaining the line/column counters. A no-op at
// end of input.
void lexer_step(Lexer* lexer)
{
    if (lexer_done(lexer))
        return;
    char c = lexer->text[lexer->index];
    lexer->index += 1;
    if (c == '\n') {
        lexer->line += 1;
        lexer->column = 1;
    } else {
        lexer->column += 1;
    }
}

// Heap-allocate a NUL-terminated copy of the token's lexeme from `text`.
// Ownership transfers to the caller, who must free() the result.
char* token_string(const Token* token, const char* text)
{
    char* lexeme = calloc(token->length + 1, sizeof(char));
    memcpy(lexeme, text + token->position.index, token->length);
    return lexeme;
}

// Render a token as "(TypeName, lexeme)" into a freshly allocated string.
// Ownership transfers to the caller, who must free() the result.
char* token_to_string(const Token* token, const char* text)
{
    const char* type_name = token_type_to_string(token->type);
    char* lexeme = token_string(token, text);
    // "(", ", ", ")" and the NUL terminator account for the extra 5 bytes.
    size_t capacity = token->length + strlen(type_name) + 5;
    char* result = calloc(capacity, sizeof(char));
    snprintf(result, capacity, "(%s, %s)", type_name, lexeme);
    free(lexeme);
    return result;
}

// Map a TokenType to its human-readable name for diagnostics.
const char* token_type_to_string(TokenType type)
{
    switch (type) {
        case TokenTypeEof:
            return "Eof";
        case TokenTypeInvalidChar:
            return "InvalidChar";
        case TokenTypeMalformedMultilineComment:
            return "MalformedMultilineComment";
        case TokenTypeMalformedChar:
            return "MalformedChar";
        case TokenTypeMalformedString:
            return "MalformedString";
        case TokenTypeId:
            return "Id";
        case TokenTypeInt:
            return "Int";
        case TokenTypeHex:
            return "Hex";
        case TokenTypeBinary:
            return "Binary";
        case TokenTypeFloat:
            return "Float";
        case TokenTypeChar:
            return "Char";
        case TokenTypeString:
            return "String";
        case TokenTypeIf:
            return "If";
        case TokenTypeElse:
            return "Else";
        case TokenTypeWhile:
            return "While";
        case TokenTypeBreak:
            return "Break";
        case TokenTypeLParen:
            return "LParen";
        case TokenTypeRParen:
            return "RParen";
        case TokenTypeLBrace:
            return "LBrace";
        case TokenTypeRBrace:
            return "RBrace";
        case TokenTypeLBracket:
            return "LBracket";
        case TokenTypeRBracket:
            return "RBracket";
        case TokenTypeDot:
            return "Dot";
        case TokenTypeComma:
            return "Comma";
        case TokenTypeColon:
            return "Colon";
        case TokenTypeSemicolon:
            return "Semicolon";
        case TokenTypePlusEqual:
            return "PlusEqual";
        case TokenTypeMinusEqual:
            return "MinusEqual";
        case TokenTypeAsteriskEqual:
            return "AsteriskEqual";
        case TokenTypeSlashEqual:
            return "SlashEqual";
        case TokenTypePercentEqual:
            return "PercentEqual";
        case TokenTypeDoubleEqual:
            return "DoubleEqual";
        case TokenTypeExclamationEqual:
            return "ExclamationEqual";
        case TokenTypeLtEqual:
            return "LtEqual";
        case TokenTypeGtEqual:
            return "GtEqual";
        case TokenTypePlus:
            return "Plus";
        case TokenTypeMinus:
            return "Minus";
        case TokenTypeAsterisk:
            return "Asterisk";
        case TokenTypeSlash:
            return "Slash";
        case TokenTypePercent:
            return "Percent";
        case TokenTypeEqual:
            return "Equal";
        case TokenTypeExclamation:
            return "Exclamation";
        case TokenTypeGt:
            return "Gt";
        case TokenTypeLt:
            return "Lt";
    }
    // BUG FIX: falling off the end of a non-void function is UB when the
    // caller uses the result (possible with an out-of-range enum value).
    // Kept outside the switch — no `default:` case — so the compiler still
    // warns if a new TokenType is added without a name here.
    return "Unknown";
}