#include "arch.hpp"
#include "vm.hpp"
#include <cstddef>
#include <cstdint>
#include <format>
#include <iostream>
#include <string>
#include <unordered_map>
#include <utility>
#include <variant>
#include <vector>

/// Discriminator for the kind of payload carried by an AsmLine.
enum class AsmLineType {
    // An opcode; assembled into one output word holding its underlying value.
    Op,
    // A literal 32-bit value; assembled into one output word as-is.
    Lit,
    // A label definition; emits nothing, records the current output position.
    Loc,
    // A label reference; emits one placeholder word that is later patched
    // with the position recorded for the label.
    Ref,
};

/// A label definition: names a position in the assembled program.
struct Loc {
    /// Takes the label name by value and moves it into place, so passing a
    /// temporary costs no extra string copy.
    explicit Loc(std::string value)
        : value(std::move(value))
    {
    }

    /// The label's name.
    std::string value;
};

/// A label reference: names the label whose position should be substituted.
struct Ref {
    /// Takes the label name by value and moves it into place, so passing a
    /// temporary costs no extra string copy.
    explicit Ref(std::string value)
        : value(std::move(value))
    {
    }

    /// The name of the referenced label.
    std::string value;
};

/// One entry of an assembly listing: an opcode, a literal word, a label
/// definition or a label reference.
///
/// The constructors are intentionally implicit so that a listing can be
/// written as a plain braced list mixing opcodes, integers, `Loc`s and `Ref`s.
struct AsmLine {
    /* clang-format off */
    AsmLine(sliger::Op value) : type(AsmLineType::Op), value(value) {}
    AsmLine(uint32_t value) : type(AsmLineType::Lit), value(value) {}
    // Loc/Ref own a std::string; move the by-value argument into the variant
    // instead of copying it a second time.
    AsmLine(Loc value) : type(AsmLineType::Loc), value(std::move(value)) {}
    AsmLine(Ref value) : type(AsmLineType::Ref), value(std::move(value)) {}
    /* clang-format on */

    /// Which alternative `value` holds; set by the constructor to match.
    AsmLineType type;
    /// The payload for this line.
    std::variant<sliger::Op, uint32_t, Loc, Ref> value;
};

/// Assembles a listing into a flat vector of 32-bit program words.
///
/// Pass 1 walks `lines` in order: opcodes and literals each emit one word,
/// a label definition records the index of the next word to be emitted (a
/// later definition of the same name replaces an earlier one), and a label
/// reference emits a zero placeholder and remembers where it sits.
///
/// Pass 2 overwrites every placeholder with the position recorded for its
/// label. A reference to a label that was never defined is reported on
/// stderr and its placeholder is left as 0; assembly still continues.
///
/// @param lines  The listing to assemble.
/// @return The assembled words, one per Op/Lit/Ref line.
auto compile_asm(const std::vector<AsmLine>& lines) -> std::vector<uint32_t>
{
    auto output = std::vector<uint32_t>();
    // Every line emits at most one word, so this never under-allocates.
    output.reserve(lines.size());

    auto locs = std::unordered_map<std::string, size_t>();
    // (word index, label name) per placeholder, in emission order. Indices
    // are unique and ascending, so a vector keeps diagnostics in program
    // order without hashing.
    auto refs = std::vector<std::pair<size_t, std::string>>();

    for (const auto& line : lines) {
        switch (line.type) {
            case AsmLineType::Op: {
                output.push_back(
                    std::to_underlying(std::get<sliger::Op>(line.value)));
                break;
            }
            case AsmLineType::Lit: {
                output.push_back(std::get<uint32_t>(line.value));
                break;
            }
            case AsmLineType::Loc: {
                // output.size() is the index of the next word to be emitted.
                locs.insert_or_assign(
                    std::get<Loc>(line.value).value, output.size());
                break;
            }
            case AsmLineType::Ref: {
                refs.emplace_back(
                    output.size(), std::get<Ref>(line.value).value);
                output.push_back(0);
                break;
            }
        }
    }

    for (const auto& [index, label] : refs) {
        const auto loc = locs.find(label);
        if (loc == locs.end()) {
            std::cerr << std::format(
                "error: label \"{}\" used at {} not defined\n", label, index);
            continue;
        }
        output[index] = static_cast<uint32_t>(loc->second);
    }
    return output;
}

/// Demo driver: assembles a hand-written listing with compile_asm, runs it on
/// a sliger::VM with flame-graph and code-coverage collection enabled, and
/// prints the VM's stack representation plus both JSON reports to stdout.
int main()
{
    // Short aliases keep label references/definitions compact in the listing.
    using R = Ref;
    using L = Loc;
    // Lets the listing name opcodes without the sliger::Op:: prefix.
    using enum sliger::Op;

    // The listing below presumably encodes this source program (TODO confirm
    // against the compiler that would normally generate it):
    //
    //  fn add(a, b) {
    //      + a b
    //  }
    //
    //  fn main() {
    //      let result = 0;
    //      let i = 0;
    //      loop {
    //          if i >= 10 {
    //              break;
    //          }
    //          result = add(result, 5);
    //          i = + i 1;
    //      }
    //      result
    //  }
    //
    // Every element converts implicitly to an AsmLine: an opcode, an integer
    // literal word, a label definition L("name") or a label reference
    // R("name") that compile_asm replaces with the label's word index.
    // NOTE(review): the meaning of each opcode's trailing literals (e.g. the
    // three values after SourceMap) is defined by the VM, not visible here.
    auto program_asm = std::vector<AsmLine> {
        // clang-format off
        SourceMap,  0, 0, 0,
        PushPtr,    R("main"),
        Call,       0,
        PushPtr,    R("_exit"),
        Jump,
        Pop,
        L("add"),
        SourceMap,  19, 2, 5,
        Add,
        Return,
        L("main"),
        SourceMap,  28, 5, 1,
        PushInt,    0,
        PushInt,    0,
        SourceMap,  44, 6, 1,
        PushInt,    0,
        SourceMap,  55, 7, 1,
        L("0"),
        SourceMap,  66, 8,  5,
        LoadLocal,  2,
        PushInt,    10,
        LessThan,
        Not,
        PushPtr,    R("1"),
        JumpIfFalse,
        SourceMap,  87, 9, 9,
        PushPtr,    R("2"),
        Jump,
        L("1"),
        SourceMap,  104, 11, 5,
        LoadLocal,  1,
        PushInt,    5,
        PushPtr,    R("add"),
        Call,       2,
        StoreLocal, 1,
        SourceMap,  133, 12, 5,
        LoadLocal,  2,
        PushInt,    1,
        Add,
        StoreLocal, 2,
        PushPtr,    R("0"),
        Jump,
        L("2"),
        LoadLocal,  1,
        StoreLocal, 0,
        Pop,
        Pop,
        Return,
        L("_exit"),
        SourceMap,  147, 15, 1
        // clang-format on
    };
    // Resolve labels and flatten the listing into program words.
    auto program = compile_asm(program_asm);
    auto vm = sliger::VM(program,
        {
            .flame_graph = true,
            .code_coverage = true,
        });
    vm.run_until_done();
    // NOTE(review): the argument 4 is presumably a limit on how much of the
    // stack is rendered — confirm against VM::stack_repr_string.
    std::cout << std::format("done\n{}\n", vm.stack_repr_string(4));
    auto flame_graph = vm.flame_graph_json();
    std::cout << std::format("flame graph: {}\n", flame_graph);
    auto code_coverage = vm.code_coverage_json();
    std::cout << std::format("code coverage: {}\n", code_coverage);
}