#include "ScanLua.h"

// Returns the directory that contains the running executable.
//
// The module file name is queried with GetModuleFileNameW. The buffer starts
// at MAX_PATH characters and is doubled while the API reports truncation, so
// long paths are not cut off. Returns an empty path if the query fails.
fs::path GetExeDir()
{
    // Upper bound for the retry loop; wide-character Win32 paths cannot exceed
    // roughly 32,767 characters.
    constexpr DWORD kMaxExtendedPath = 32768;

    std::wstring buffer(MAX_PATH, L'\0');
    for (;;)
    {
        const DWORD capacity = static_cast<DWORD>(buffer.size());
        const DWORD written = GetModuleFileNameW(nullptr, buffer.data(), capacity);

        // Zero means the call failed; the buffer content is not usable.
        if (written == 0)
            return fs::path();

        // A result shorter than the buffer means the full path was copied.
        if (written < capacity)
        {
            buffer.resize(written);
            return fs::path(buffer).parent_path();
        }

        // written == capacity: the path was truncated. Retry with more room,
        // but give up once the buffer is already larger than any valid path.
        if (capacity >= kMaxExtendedPath)
            return fs::path();

        buffer.resize(static_cast<std::wstring::size_type>(capacity) * 2);
    }
}

// Reads the whole file at `path` in binary mode and returns its bytes.
//
// No error is reported: a file that cannot be opened yields an empty string.
std::string ReadFileBinary(const std::filesystem::path& path)
{
    std::ifstream input(path, std::ios::binary);

    // Pull every byte straight from the stream buffer until end of file.
    std::string contents;
    contents.assign(std::istreambuf_iterator<char>(input),
                    std::istreambuf_iterator<char>());
    return contents;
}

// Converts a UTF-8 encoded byte string to a wide (wchar_t) string using
// MultiByteToWideChar with CP_UTF8.
//
// MB_ERR_INVALID_CHARS is passed, so input that is not valid UTF-8 makes the
// conversion fail instead of being converted leniently.
//
// Returns an empty string when the input is empty, is too large for the
// API's int length parameter, or cannot be converted.
std::wstring Utf8ToWide(const std::string& utf8)
{
    // The API takes the byte count as an int. Reject sizes that do not survive
    // the narrowing unchanged instead of passing a wrapped value.
    const int byteCount = static_cast<int>(utf8.size());
    if (byteCount <= 0 || static_cast<std::string::size_type>(byteCount) != utf8.size())
        return {};

    // First pass: ask how many wide characters the result needs.
    const int wideCount = MultiByteToWideChar(
        CP_UTF8,
        MB_ERR_INVALID_CHARS,
        utf8.data(),
        byteCount,
        nullptr,
        0
    );

    if (wideCount <= 0)
        return {};

    std::wstring wide(static_cast<std::wstring::size_type>(wideCount), L'\0');

    // Second pass: perform the conversion into the sized buffer.
    const int written = MultiByteToWideChar(
        CP_UTF8,
        MB_ERR_INVALID_CHARS,
        utf8.data(),
        byteCount,
        wide.data(),
        wideCount
    );

    // Do not hand back a buffer of NULs if the conversion itself failed.
    if (written <= 0)
        return {};

    wide.resize(static_cast<std::wstring::size_type>(written));
    return wide;
}

// Adds the elements of `ws` to `outChars`, one entry per wchar_t element.
//
// Two kinds of elements are left out:
//   - control characters below 0x20, except '\n' and '\t';
//   - values in the UTF-16 surrogate range 0xD800-0xDFFF, so characters that
//     are stored as surrogate pairs (emoji, for example) are not collected.
// Entries already present in `outChars` are kept.
void CollectCharsFromWide(
    const std::wstring& ws,
    std::set<char32_t>& outChars)
{
    for (const wchar_t unit : ws)
    {
        const bool keptWhitespace = (unit == L'\n' || unit == L'\t');
        const bool isControl = (unit < 0x20) && !keptWhitespace;
        const bool isSurrogate = (unit >= 0xD800) && (unit <= 0xDFFF);

        if (!isControl && !isSurrogate)
            outChars.insert(static_cast<char32_t>(unit));
    }
}

// Walks `root` recursively and, for every regular file whose extension is
// exactly ".lua", reads it, converts it from UTF-8 to wide characters and
// adds its characters to `outChars`.
//
// The character filtering is whatever CollectCharsFromWide applies. Errors
// raised by the directory iteration are not caught here.
void ScanLuaFiles(const std::filesystem::path& root, std::set<char32_t>& outChars)
{
    for (const std::filesystem::directory_entry& item :
         std::filesystem::recursive_directory_iterator(root))
    {
        const bool isLuaScript =
            item.is_regular_file() && item.path().extension() == ".lua";
        if (!isLuaScript)
            continue;

        // Read raw bytes -> decode as UTF-8 -> accumulate characters.
        const std::string bytes = ReadFileBinary(item.path());
        const std::wstring text = Utf8ToWide(bytes);
        CollectCharsFromWide(text, outChars);
    }
}

// Writes the characters in `chars` to `outPath` as UTF-8, with no separators
// and no byte-order mark, in the set's iteration order (ascending code point).
//
// Skipped entries:
//   - control characters below 0x20, except '\n' and '\t';
//   - surrogate values 0xD800-0xDFFF, which are not valid code points;
//   - values above 0x10FFFF, the last Unicode code point.
//
// If nothing remains after filtering, the function returns without creating
// or touching the output file. Stream errors are not reported.
//
// Code points are encoded to UTF-8 directly, so no intermediate UTF-16 buffer
// or platform conversion call is involved.
void WriteCharSet(const std::set<char32_t>& chars, const std::filesystem::path& outPath)
{
    std::string utf8;
    // Most entries need at most three bytes; this is only a capacity hint.
    utf8.reserve(chars.size() * 3);

    for (const char32_t c : chars)
    {
        if (c < 0x20 && c != U'\n' && c != U'\t')
            continue;

        if (c > 0x10FFFF)
            continue;

        if (c >= 0xD800 && c <= 0xDFFF)
            continue;

        if (c < 0x80)
        {
            // 1 byte: 0xxxxxxx
            utf8.push_back(static_cast<char>(c));
        }
        else if (c < 0x800)
        {
            // 2 bytes: 110xxxxx 10xxxxxx
            utf8.push_back(static_cast<char>(0xC0 | (c >> 6)));
            utf8.push_back(static_cast<char>(0x80 | (c & 0x3F)));
        }
        else if (c < 0x10000)
        {
            // 3 bytes: 1110xxxx 10xxxxxx 10xxxxxx
            utf8.push_back(static_cast<char>(0xE0 | (c >> 12)));
            utf8.push_back(static_cast<char>(0x80 | ((c >> 6) & 0x3F)));
            utf8.push_back(static_cast<char>(0x80 | (c & 0x3F)));
        }
        else
        {
            // 4 bytes: 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx
            utf8.push_back(static_cast<char>(0xF0 | (c >> 18)));
            utf8.push_back(static_cast<char>(0x80 | ((c >> 12) & 0x3F)));
            utf8.push_back(static_cast<char>(0x80 | ((c >> 6) & 0x3F)));
            utf8.push_back(static_cast<char>(0x80 | (c & 0x3F)));
        }
    }

    if (utf8.empty())
        return;

    std::ofstream out(outPath, std::ios::binary);
    if (!out)
        return;

    out.write(utf8.data(), static_cast<std::streamsize>(utf8.size()));
}