#include "vm.hpp"
#include "arch.hpp"
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <format>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

using namespace sliger;

/// Executes instructions one at a time until `done()` reports true, then
/// asks the flame graph to calculate a midway result for the current
/// instruction counter.
void VM::run_until_done()
{
    for (;;) {
        if (done()) {
            break;
        }
        run_instruction();
    }
    const auto executed = this->instruction_counter;
    this->flame_graph.calculate_midway_result(executed);
}

/// Executes at most `amount` instructions, stopping early as soon as
/// `done()` reports true, then asks the flame graph to calculate a midway
/// result for the current instruction counter.
///
/// @param amount Upper bound on the number of instructions to execute.
void VM::run_n_instructions(size_t amount)
{
    size_t remaining = amount;
    while (remaining > 0 && !done()) {
        run_instruction();
        --remaining;
    }
    const auto executed = this->instruction_counter;
    this->flame_graph.calculate_midway_result(executed);
}

/// Decodes and executes the single instruction at the current program
/// counter, then increments the instruction counter.
///
/// Operands are read from the program with the `eat_*` helpers; values are
/// exchanged through the value stack. Malformed programs (missing operands,
/// stack underflow, division by zero) terminate the process with exit
/// status 1 after printing a diagnostic to stderr.
void VM::run_instruction()
{
    if (this->opts.print_debug) {
        // Trace the pc, the upcoming op, and up to `stack.size() - bp` of the
        // topmost stack values.
        auto stack_frame_size = this->stack.size() - this->bp;
        std::cout << std::format("    {:>4}: {:<12}{}\n", this->pc,
            maybe_op_to_string(this->program[this->pc]),
            stack_repr_string(stack_frame_size));
    }
    auto op = eat_op();
    switch (op) {
        case Op::Nop:
            // nothing
            break;
        case Op::PushNull:
            this->stack.push_back(Null {});
            break;
        case Op::PushInt: {
            assert_program_has(1);
            auto value = eat_int32();
            this->stack.push_back(Int { value });
            break;
        }
        case Op::PushBool: {
            // Any non-zero operand is pushed as true.
            assert_program_has(1);
            auto value = eat_int32();
            this->stack.push_back(Bool { .value = value != 0 });
            break;
        }
        case Op::PushString: {
            // Operands: a length, followed by one program word per character.
            assert_program_has(1);
            auto string_length = eat_uint32();
            assert_program_has(string_length);
            auto value = std::string();
            for (uint32_t i = 0; i < string_length; ++i) {
                auto ch = eat_uint32();
                value.push_back(static_cast<char>(ch));
            }
            stack_push(String { .value = std::move(value) });
            break;
        }
        case Op::PushPtr: {
            assert_program_has(1);
            auto value = eat_uint32();
            this->stack.push_back(Ptr { value });
            break;
        }
        case Op::Pop: {
            assert_stack_has(1);
            this->stack.pop_back();
            break;
        }
        case Op::ReserveStatic: {
            assert_program_has(1);
            auto value = eat_uint32();
            this->statics.reserve(value);
            break;
        }
        case Op::LoadStatic: {
            assert_program_has(1);
            auto loc = eat_uint32();
            auto value = this->statics.at(loc);
            stack_push(value);
            break;
        }
        case Op::StoreStatic: {
            assert_program_has(1);
            auto loc = eat_uint32();
            auto value = stack_pop();
            this->statics.at(loc) = value;
            break;
        }
        case Op::LoadLocal: {
            assert_program_has(1);
            auto loc = eat_uint32();
            // Slot `loc` only exists when the frame holds at least loc + 1
            // values (same bound StoreLocal checks). Widen first so that
            // loc + 1 cannot wrap around in 32 bits.
            assert_fn_stack_has(static_cast<size_t>(loc) + 1);
            auto value = fn_stack_at(loc);
            stack_push(value);
            break;
        }
        case Op::StoreLocal: {
            assert_program_has(1);
            auto loc = eat_uint32();
            assert_fn_stack_has(loc + 1);
            auto value = stack_pop();
            fn_stack_at(loc) = value;
            break;
        }
        case Op::Call: {
            assert_program_has(1);
            auto arg_count = eat_uint32();
            // One extra value for the function pointer on top of the
            // arguments; widen so arg_count + 1 cannot wrap around.
            assert_stack_has(static_cast<size_t>(arg_count) + 1);
            auto fn_ptr = stack_pop();
            auto arguments = std::vector<Value>();
            arguments.reserve(arg_count);
            for (uint32_t i = 0; i < arg_count; ++i) {
                arguments.push_back(stack_pop());
            }
            // Save the return address and the caller's base pointer, then
            // start the callee's frame at the current top of the stack.
            stack_push(Ptr { .value = this->pc });
            stack_push(Ptr { .value = this->bp });
            this->pc = fn_ptr.as_ptr().value;
            this->bp = static_cast<uint32_t>(this->stack.size());
            // Arguments were popped top-first; pushing them back in reverse
            // restores their original order inside the new frame.
            for (auto arg = arguments.rbegin(); arg != arguments.rend();
                 ++arg) {
                stack_push(*arg);
            }
            if (this->opts.flame_graph) {
                this->flame_graph.report_call(
                    fn_ptr.as_ptr().value, this->instruction_counter);
            }
            break;
        }
        case Op::Return: {
            assert_stack_has(3);
            auto ret_val = stack_pop();
            // Discard everything the callee left above its base pointer.
            while (this->stack.size() > this->bp) {
                stack_pop();
            }
            // Restore the values Call saved: base pointer first, then pc.
            auto bp_val = stack_pop();
            auto pc_val = stack_pop();
            this->bp = bp_val.as_ptr().value;
            stack_push(ret_val);
            this->pc = pc_val.as_ptr().value;
            if (this->opts.flame_graph) {
                this->flame_graph.report_return(this->instruction_counter);
            }
            break;
        }
        case Op::Jump: {
            assert_stack_has(1);
            auto addr = stack_pop();
            this->pc = addr.as_ptr().value;
            break;
        }
        case Op::JumpIfTrue: {
            assert_stack_has(2);
            auto addr = stack_pop();
            auto cond = stack_pop();
            if (cond.as_bool().value) {
                this->pc = addr.as_ptr().value;
            }
            break;
        }
        case Op::Builtin: {
            assert_program_has(1);
            auto builtin_id = eat_uint32();
            run_builtin(static_cast<Builtin>(builtin_id));
            break;
        }
        case Op::Duplicate: {
            assert_stack_has(1);
            auto value = stack_pop();
            stack_push(value);
            stack_push(value);
            break;
        }
        case Op::Swap: {
            assert_stack_has(2);
            auto right = stack_pop();
            auto left = stack_pop();
            stack_push(right);
            stack_push(left);
            break;
        }
        case Op::Add: {
            assert_stack_has(2);
            auto right = stack_pop().as_int().value;
            auto left = stack_pop().as_int().value;
            auto value = left + right;
            stack_push(Int { .value = value });
            break;
        }
        case Op::Subtract: {
            assert_stack_has(2);
            auto right = stack_pop().as_int().value;
            auto left = stack_pop().as_int().value;
            auto value = left - right;
            stack_push(Int { .value = value });
            break;
        }
        case Op::Multiply: {
            assert_stack_has(2);
            auto right = stack_pop().as_int().value;
            auto left = stack_pop().as_int().value;
            auto value = left * right;
            stack_push(Int { .value = value });
            break;
        }
        case Op::Divide: {
            assert_stack_has(2);
            auto right = stack_pop().as_int().value;
            auto left = stack_pop().as_int().value;
            // Integer division by zero is undefined behaviour in C++; report
            // it like the other malformed-program conditions instead.
            if (right == 0) {
                std::cerr << std::format(
                    "division by zero, pc = {}\n", this->pc);
                std::exit(1);
            }
            auto value = left / right;
            stack_push(Int { .value = value });
            break;
        }
        case Op::Remainder: {
            assert_stack_has(2);
            auto right = stack_pop().as_int().value;
            auto left = stack_pop().as_int().value;
            // Same undefined behaviour as Divide when the divisor is zero.
            if (right == 0) {
                std::cerr << std::format(
                    "division by zero, pc = {}\n", this->pc);
                std::exit(1);
            }
            auto value = left % right;
            stack_push(Int { .value = value });
            break;
        }
        case Op::Equal: {
            assert_stack_has(2);
            auto right = stack_pop().as_int().value;
            auto left = stack_pop().as_int().value;
            auto value = left == right;
            stack_push(Bool { .value = value });
            break;
        }
        case Op::LessThan: {
            assert_stack_has(2);
            auto right = stack_pop().as_int().value;
            auto left = stack_pop().as_int().value;
            auto value = left < right;
            stack_push(Bool { .value = value });
            break;
        }
        case Op::And: {
            assert_stack_has(2);
            auto right = stack_pop().as_bool().value;
            auto left = stack_pop().as_bool().value;
            auto value = left && right;
            stack_push(Bool { .value = value });
            break;
        }
        case Op::Or: {
            assert_stack_has(2);
            auto right = stack_pop().as_bool().value;
            auto left = stack_pop().as_bool().value;
            auto value = left || right;
            stack_push(Bool { .value = value });
            break;
        }
        case Op::Xor: {
            assert_stack_has(2);
            auto right = stack_pop().as_bool().value;
            auto left = stack_pop().as_bool().value;
            // Exclusive or: true exactly when the operands differ.
            auto value = left != right;
            stack_push(Bool { .value = value });
            break;
        }
        case Op::Not: {
            assert_stack_has(1);
            auto value = !stack_pop().as_bool().value;
            stack_push(Bool { .value = value });
            break;
        }
        case Op::SourceMap: {
            assert_program_has(3);
            auto index = eat_int32();
            auto line = eat_int32();
            auto col = eat_int32();
            // Coverage is reported for the position recorded before this
            // instruction; current_pos is replaced afterwards.
            if (opts.code_coverage) {
                this->code_coverage.report_cover(this->current_pos);
            }
            this->current_pos = { index, line, col };
            break;
        }
    }
    this->instruction_counter += 1;
}

/// Dispatches a builtin call to its implementation.
///
/// `Exit` and `IntToString` are handled inline; string, array and
/// print/file builtins are forwarded to their dedicated handlers.
/// `StructSet` reports "not implemented" on stderr and exits with status 1.
///
/// @param builtin_id Identifier of the builtin to execute.
void VM::run_builtin(Builtin builtin_id)
{
    if (this->opts.print_debug) {
        const auto raw_id = static_cast<uint32_t>(builtin_id);
        std::cout << std::format(
            "Running builtin {}\n", maybe_builtin_to_string(raw_id));
    }

    switch (builtin_id) {
        // Handled directly here.
        case Builtin::Exit: {
            assert_stack_has(1);
            auto exit_status = stack_pop().as_int().value;
            std::exit(exit_status);
            break;
        }
        case Builtin::IntToString: {
            assert_stack_has(1);
            auto int_value = stack_pop().as_int().value;
            stack_push(String(std::to_string(int_value)));
            break;
        }
        case Builtin::StructSet: {
            assert_stack_has(2);
            std::cerr << std::format("not implemented\n");
            std::exit(1);
            break;
        }

        // String operations.
        case Builtin::StringConcat:
        case Builtin::StringEqual:
        case Builtin::StringCharAt:
        case Builtin::StringLength:
        case Builtin::StringPushChar:
        case Builtin::StringToInt: {
            run_string_builtin(builtin_id);
            break;
        }

        // Array operations.
        case Builtin::ArrayNew:
        case Builtin::ArraySet:
        case Builtin::ArrayPush:
        case Builtin::ArrayAt:
        case Builtin::ArrayLength: {
            run_array_builtin(builtin_id);
            break;
        }

        // Printing and file operations.
        case Builtin::Print:
        case Builtin::FileOpen:
        case Builtin::FileClose:
        case Builtin::FileWriteString:
        case Builtin::FileReadChar:
        case Builtin::FileReadToString:
        case Builtin::FileFlush:
        case Builtin::FileEof: {
            run_file_builtin(builtin_id);
            break;
        }
    }
}

/// Executes one of the string builtins. Operands are popped from the value
/// stack (the last-pushed operand first) and the result is pushed back.
/// Identifiers that are not string builtins are ignored.
///
/// @param builtin_id Identifier of the string builtin to execute.
void VM::run_string_builtin(Builtin builtin_id)
{
    switch (builtin_id) {
        case Builtin::StringConcat: {
            assert_stack_has(2);
            auto right = stack_pop();
            auto left = stack_pop();
            stack_push(
                String(left.as_string().value + right.as_string().value));
            break;
        }
        case Builtin::StringEqual: {
            assert_stack_has(2);
            auto right = stack_pop();
            auto left = stack_pop();
            stack_push(Bool(left.as_string().value == right.as_string().value));
            break;
        }
        case Builtin::StringCharAt: {
            // Stack (top first): index, string. Pushes the character as Int.
            assert_stack_has(2);
            auto index_value = stack_pop();
            auto string_value = stack_pop();
            auto index = static_cast<int32_t>(index_value.as_int().value);
            auto string = string_value.as_string();
            stack_push(Int(string.at(index)));
            break;
        }
        case Builtin::StringLength: {
            assert_stack_has(1);
            auto str = stack_pop().as_string().value;

            auto length = static_cast<int32_t>(str.length());
            stack_push(Int(length));
            break;
        }
        case Builtin::StringPushChar: {
            // Stack (top first): character code, string. Pushes a new string
            // with the character appended; the operand string is not mutated.
            assert_stack_has(2);
            auto ch = stack_pop();
            auto str = stack_pop();

            auto new_str = std::string(str.as_string().value);
            new_str.push_back(static_cast<char>(ch.as_int().value));
            stack_push(String(new_str));
            break;
        }
        case Builtin::StringToInt: {
            assert_stack_has(1);
            auto str = stack_pop().as_string().value;
            // atoi has undefined behaviour when the value does not fit in an
            // int. strtol is well-defined (it saturates at the limits of
            // long), so parse with it and clamp to the 32-bit range. Input
            // with no leading number still yields 0, as with atoi.
            long parsed = std::strtol(str.c_str(), nullptr, 10);
            if (parsed > INT32_MAX) {
                parsed = INT32_MAX;
            } else if (parsed < INT32_MIN) {
                parsed = INT32_MIN;
            }
            stack_push(Int(static_cast<int32_t>(parsed)));
            break;
        }
        default:
            break;
    }
}
/// Executes one of the array builtins. Arrays live on the VM heap and are
/// referred to by pointer values; operands are popped from the value stack
/// and the result is pushed back. Identifiers that are not array builtins
/// are ignored.
///
/// @param builtin_id Identifier of the array builtin to execute.
void VM::run_array_builtin(Builtin builtin_id)
{
    switch (builtin_id) {
        case Builtin::ArrayNew: {
            auto alloc_res = this->heap.alloc<heap::AllocType::Array>();
            stack_push(Ptr(alloc_res.val()));
            break;
        }
        case Builtin::ArraySet: {
            // Three operands are popped (top first): index, array pointer,
            // value — so three values must be present, not two.
            assert_stack_has(3);
            auto index = stack_pop().as_int().value;
            auto array_ptr = stack_pop().as_ptr().value;
            auto value = stack_pop();

            this->heap.at(array_ptr).val()->as_array().at(index) = value;
            stack_push(Null());
            break;
        }
        case Builtin::ArrayPush: {
            // Stack (top first): value, array pointer.
            assert_stack_has(2);
            auto value = stack_pop();
            auto array_ptr = stack_pop().as_ptr().value;

            this->heap.at(array_ptr).val()->as_array().values.push_back(value);
            stack_push(Null());
            break;
        }
        case Builtin::ArrayAt: {
            // Stack (top first): index, array pointer.
            assert_stack_has(2);
            auto index = stack_pop().as_int().value;
            auto array_ptr = stack_pop().as_ptr().value;

            // Read the element in place rather than copying the whole array
            // into a local first.
            stack_push(this->heap.at(array_ptr).val()->as_array().at(index));
            break;
        }
        case Builtin::ArrayLength: {
            assert_stack_has(1);
            auto array_ptr = stack_pop().as_ptr().value;

            // Query the size in place rather than copying the whole array.
            auto length
                = this->heap.at(array_ptr).val()->as_array().values.size();
            stack_push(Int(static_cast<int32_t>(length)));
            break;
        }
        default:
            break;
    }
}

/// Executes the print builtin or one of the file builtins. Files are
/// referred to by the integer ids handed out by `FileOpen`. Operands are
/// popped from the value stack and the result is pushed back. Identifiers
/// that are not print/file builtins are ignored.
///
/// Except for `FileClose`, using an id with no open file prints an error to
/// stderr and exits with status 1.
///
/// @param builtin_id Identifier of the builtin to execute.
void VM::run_file_builtin(Builtin builtin_id)
{
    // Returns the FILE* registered under `file_id`, or reports the error and
    // terminates the process when no such file is open.
    auto require_open_file = [this](auto file_id) -> FILE* {
        auto entry = this->open_files.find(file_id);
        if (entry == this->open_files.end()) {
            std::cerr << std::format("error: no open file {}\n", file_id);
            std::exit(1);
        }
        return entry->second;
    };

    switch (builtin_id) {
        case Builtin::Print: {
            assert_stack_has(1);
            auto message = stack_pop().as_string().value;
            std::cout << message;
            stack_push(Null());
            break;
        }
        case Builtin::FileOpen: {
            // Stack (top first): mode, filename. Pushes the new file id.
            assert_stack_has(2);
            auto mode = stack_pop().as_string().value;
            auto filename = stack_pop().as_string().value;
            FILE* fp = std::fopen(filename.c_str(), mode.c_str());
            if (fp == nullptr) {
                std::cerr << std::format(
                    "error: could not open file '{}'\n", filename);
                std::exit(1);
            }
            auto file_id = this->file_id_counter;
            this->file_id_counter += 1;
            this->open_files.insert_or_assign(file_id, fp);
            stack_push(Int(file_id));
            break;
        }
        case Builtin::FileClose: {
            // Only the file id is popped, so one value is all that is needed.
            assert_stack_has(1);
            auto file_id = stack_pop().as_int().value;
            auto entry = this->open_files.find(file_id);
            // Closing an id that is not open is a silent no-op.
            if (entry != this->open_files.end()) {
                std::fclose(entry->second);
                this->open_files.erase(entry);
            }
            stack_push(Null());
            break;
        }
        case Builtin::FileWriteString: {
            // Stack (top first): content, file id. Pushes 0 on success and
            // -1 on failure.
            assert_stack_has(2);
            auto content = stack_pop().as_string().value;
            auto file_id = stack_pop().as_int().value;
            FILE* fp = require_open_file(file_id);
            // fputs returns EOF on failure and any non-negative value
            // (including 0) on success.
            auto res = std::fputs(content.c_str(), fp);
            if (res == EOF) {
                stack_push(Int(-1));
                break;
            }
            stack_push(Int(0));
            break;
        }
        case Builtin::FileReadChar: {
            assert_stack_has(1);
            auto file_id = stack_pop().as_int().value;
            FILE* fp = require_open_file(file_id);
            // The fgetc result is pushed unchanged, so end-of-file or a read
            // error shows up as EOF.
            int value = std::fgetc(fp);
            stack_push(Int(value));
            break;
        }
        case Builtin::FileReadToString: {
            assert_stack_has(1);
            auto file_id = stack_pop().as_int().value;
            FILE* fp = require_open_file(file_id);
            auto content = std::string();
            while (true) {
                constexpr size_t buf_size = 128;
                char buf[buf_size];
                auto res = std::fread(buf, 1, buf_size, fp);
                if (res == 0) {
                    break;
                }
                // Append exactly the bytes that were read; going through a
                // C string would cut the data at the first NUL byte.
                content.append(buf, res);
            }
            stack_push(String(std::move(content)));
            break;
        }
        case Builtin::FileFlush: {
            assert_stack_has(1);
            auto file_id = stack_pop().as_int().value;
            FILE* fp = require_open_file(file_id);
            std::fflush(fp);
            stack_push(Null());
            break;
        }
        case Builtin::FileEof: {
            assert_stack_has(1);
            auto file_id = stack_pop().as_int().value;
            FILE* fp = require_open_file(file_id);
            stack_push(Bool(std::feof(fp) != 0));
            break;
        }
        default:
            break;
    }
}

/// Builds a one-line textual view of the stack, topmost value first.
///
/// The line starts with "→", followed by at most `max_items` values, each
/// formatted left-aligned in an 11-character column and separated by single
/// spaces. When the stack holds more than `max_items` values, a
/// " ... + N" suffix gives the number of values left out.
///
/// @param max_items Maximum number of stack values to include.
/// @return The formatted representation.
auto VM::stack_repr_string(size_t max_items) const -> std::string
{
    const auto& items = view_stack();
    const size_t total = items.size();
    const size_t shown = total < max_items ? total : max_items;

    std::string repr = "→";
    for (size_t depth = 0; depth < shown; ++depth) {
        if (depth > 0) {
            repr += " ";
        }
        // depth 0 is the top of the stack, i.e. the last element.
        const auto& item = items[total - depth - 1];
        repr += std::format("{:<11}", item.to_repr_string());
    }
    if (total > max_items) {
        repr += std::format(" ... + {}", total - max_items);
    }
    return repr;
}

/// Verifies that at least `count` more program words are available starting
/// at the current program counter. Otherwise prints a diagnostic to stderr
/// and exits with status 1.
///
/// @param count Number of program words the caller is about to read.
void VM::assert_program_has(size_t count)
{
    if (this->pc + count > program.size()) {
        // Newline-terminated like the stack underflow diagnostics, so the
        // message does not run into whatever is printed next.
        std::cerr << std::format("malformed program, pc = {}\n", this->pc);
        std::exit(1);
    }
}

/// Verifies that at least `count` values sit on the stack above the base
/// pointer. Otherwise prints a diagnostic to stderr and exits with
/// status 1.
///
/// @param count Minimum number of values required above the base pointer.
void VM::assert_fn_stack_has(size_t count)
{
    const size_t total = this->stack.size();
    // Check the base pointer first: if it lies beyond the end of the stack,
    // the unsigned subtraction below would wrap around to a huge value and
    // the underflow would go unnoticed.
    const bool bp_in_range = this->bp <= total;
    if (!bp_in_range || total - this->bp < count) {
        std::cerr << std::format("stack underflow, pc = {}\n", this->pc);
        std::exit(1);
    }
}

/// Verifies that the value stack holds at least `count` values. Otherwise
/// prints a diagnostic to stderr and exits with status 1.
///
/// @param count Minimum number of values required on the stack.
void VM::assert_stack_has(size_t count)
{
    const bool enough_values = this->stack.size() >= count;
    if (enough_values) {
        return;
    }
    std::cerr << std::format("stack underflow, pc = {}\n", this->pc);
    std::exit(1);
}