#include "MyUtil.h"
#include <wrl/client.h> // Microsoft::WRL::ComPtr
#include <stdexcept>
using Microsoft::WRL::ComPtr;

/// File-wide cache for the interarea scaling path: source textures keyed by
/// file name, plus the lazily created compute shader. Pointers stored here are
/// owned by this object and Released by ReleaseAll() / the destructor.
///
/// Non-copyable and non-movable: a copy would share the raw COM pointers and
/// Release each of them twice.
///
/// NOTE(review): g_interAreaData is destroyed during static destruction, which
/// may run after the renderer/device has already been torn down. Calling
/// ReleaseAll() from the renderer's shutdown path releases these objects at a
/// well-defined time (the destructor then has nothing left to do).
struct InterAreaData
{
    std::unordered_map<std::string, ID3D11Texture2D*> textureSrc;
    ID3D11ComputeShader* cs_interarea = nullptr;

    InterAreaData() = default;
    InterAreaData(const InterAreaData&) = delete;
    InterAreaData& operator=(const InterAreaData&) = delete;
    InterAreaData(InterAreaData&&) = delete;
    InterAreaData& operator=(InterAreaData&&) = delete;

    /// Releases every cached texture and the compute shader.
    /// Safe to call more than once; the cache is left empty.
    void ReleaseAll()
    {
        for (auto& entry : textureSrc)
        {
            if (entry.second)
                entry.second->Release();
        }
        textureSrc.clear();
        if (cs_interarea)
        {
            cs_interarea->Release();
            // Nulled so a second call cannot double-Release, and so the shader
            // is recreated on the next request instead of reusing a dead pointer.
            cs_interarea = nullptr;
        }
    }

    ~InterAreaData()
    {
        ReleaseAll();
    }
}g_interAreaData;



/// Returns the interarea compute shader, compiling it on first use.
///
/// The shader is built from "./data/shader/sys_interarea.hlsl" (entry point
/// "CSMain") through the renderer and stored in g_interAreaData, which keeps
/// ownership; callers must not Release the returned pointer.
///
/// @param pRenderer Renderer used to create the shader when none is cached yet.
/// @return The cached shader. NOTE(review): this is presumably nullptr if creation
///         failed (the result of MyCreateComputeShader is not checked here). In
///         that case creation is attempted again on the next call.
ID3D11ComputeShader* GetInterAreaShader(MyDX11Renderer* pRenderer)
{
    ID3D11ComputeShader*& cachedShader = g_interAreaData.cs_interarea;
    if (!cachedShader)
    {
        pRenderer->MyCreateComputeShader(&cachedShader, "./data/shader/sys_interarea.hlsl", "CSMain");
    }
    return cachedShader;
}

/// Returns the source texture for `filename`. On the first request it is loaded
/// through the renderer and cached in g_interAreaData for later calls.
///
/// @param filename  Path of the image to load; nullptr yields nullptr.
/// @param pRenderer Renderer used to load the texture on a cache miss.
/// @return A borrowed pointer owned by the cache (do NOT Release it), or nullptr
///         if the texture is not cached and could not be loaded.
///
/// Failed loads are not cached, so a later call retries the load instead of
/// returning a cached nullptr forever.
ID3D11Texture2D* GetInterAreaSrcTexture(const char* filename, MyDX11Renderer* pRenderer)
{
    if (filename == nullptr)
        return nullptr; // constructing std::string from nullptr is undefined behavior

    // Build the key once; it is reused for both the lookup and the insert.
    const std::string key(filename);
    auto iter = g_interAreaData.textureSrc.find(key);
    if (iter != g_interAreaData.textureSrc.end() && iter->second != nullptr)
        return iter->second;

    if (pRenderer == nullptr)
        return nullptr;

    ID3D11Texture2D* tex = nullptr;
    pRenderer->MyLoadTexture(&tex, filename);
    if (tex != nullptr)
        g_interAreaData.textureSrc[key] = tex; // inserts, or replaces a stale null entry
    return tex;
}

void ScaleImageWithComputeShader(
    ID3D11Texture2D** outTexture,
    ID3D11Texture2D* inTexture,
    MyDX11Renderer* pRenderer,
    ID3D11DeviceContext* pContext,
    float scaleW,
    float scaleH
) {
    if (!inTexture || !pRenderer || !pContext || scaleW <= 0 || scaleH <= 0) {
        throw std::invalid_argument("Invalid input parameters");
    }

    // ȡ
    D3D11_TEXTURE2D_DESC inDesc;
    inTexture->GetDesc(&inDesc);

    // ĳߴ
    UINT outWidth = static_cast<UINT>(inDesc.Width * scaleW);
    UINT outHeight = static_cast<UINT>(inDesc.Height * scaleH);

    // 
    D3D11_TEXTURE2D_DESC outDesc = {};
    outDesc.Width = outWidth;
    outDesc.Height = outHeight;
    outDesc.MipLevels = 1;
    outDesc.ArraySize = 1;
    outDesc.Format = inDesc.Format; // ͬĸʽRGBA32
    outDesc.SampleDesc.Count = 1;
    outDesc.SampleDesc.Quality = 0;
    outDesc.Usage = D3D11_USAGE_DEFAULT;
    outDesc.BindFlags = D3D11_BIND_UNORDERED_ACCESS | D3D11_BIND_SHADER_RESOURCE;
    outDesc.CPUAccessFlags = 0;
    outDesc.MiscFlags = 0;

    auto pDevice = pRenderer->GetDevice();
    // 
    HRESULT hr = pDevice->CreateTexture2D(&outDesc, nullptr, outTexture);
    if (FAILED(hr))
        throw std::runtime_error("Failed to create output texture");

    //  SRV
    ComPtr<ID3D11ShaderResourceView> pInputSRV;
    D3D11_SHADER_RESOURCE_VIEW_DESC srvDesc = {};
    srvDesc.Format = inDesc.Format;
    srvDesc.ViewDimension = D3D11_SRV_DIMENSION_TEXTURE2D;
    srvDesc.Texture2D.MipLevels = 1;
    hr = pDevice->CreateShaderResourceView(inTexture, &srvDesc, &pInputSRV);
    if (FAILED(hr))
        throw std::runtime_error("Failed to create input SRV");

    //  UAV
    ComPtr<ID3D11UnorderedAccessView> pOutputUAV;
    D3D11_UNORDERED_ACCESS_VIEW_DESC uavDesc = {};
    uavDesc.Format = outDesc.Format;
    uavDesc.ViewDimension = D3D11_UAV_DIMENSION_TEXTURE2D;
    uavDesc.Texture2D.MipSlice = 0;
    hr = pDevice->CreateUnorderedAccessView(*outTexture, &uavDesc, &pOutputUAV);
    if (FAILED(hr))
        throw std::runtime_error("Failed to create output UAV");

    //  Compute Shader
    auto pComputeShader = GetInterAreaShader(pRenderer);

    //  Compute Shader Դ
    pContext->CSSetShader(pComputeShader, nullptr, 0);
    pContext->CSSetShaderResources(0, 1, pInputSRV.GetAddressOf());
    pContext->CSSetUnorderedAccessViews(0, 1, pOutputUAV.GetAddressOf(), nullptr);

    // ߳
    UINT threadGroupX = (outWidth + 15) / 16;
    UINT threadGroupY = (outHeight + 15) / 16;

    // ִ Compute Shader
    pContext->Dispatch(threadGroupX, threadGroupY, 1);

    // Դ
    pContext->CSSetShader(nullptr, nullptr, 0);
    ID3D11ShaderResourceView* nullSRV = nullptr;
    pContext->CSSetShaderResources(0, 1, &nullSRV);
    ID3D11UnorderedAccessView* nullUAV = nullptr;
    pContext->CSSetUnorderedAccessViews(0, 1, &nullUAV, nullptr);
}


/// Loads the texture for `filename` (or reuses the cached one) and writes a copy
/// scaled by (scaleW, scaleH) to *outTexture. It does this via
/// ScaleImageWithComputeShader, on the renderer's own device context.
///
/// NOTE(review): despite the name, this runs the "interarea" compute shader (see
/// GetInterAreaShader), not a linear filter. "Scare" is presumably a typo for
/// "Scale". The name is kept for existing callers.
///
/// @param outTexture Receives a newly created texture that the caller must Release.
/// @param pRenderer  Renderer used for loading, the device and the context.
/// @param filename   Image file to load.
/// @throws std::invalid_argument if pRenderer is null, the source texture could
///         not be loaded, or a scale factor is invalid.
/// @throws std::runtime_error if GPU resource creation fails.
void LinearScarePNG(ID3D11Texture2D** outTexture, MyDX11Renderer* pRenderer, const char* filename, float scaleW, float scaleH)
{
    // Must be checked here: pRenderer is dereferenced (texture load, GetContext)
    // before ScaleImageWithComputeShader gets a chance to validate it.
    if (!pRenderer)
        throw std::invalid_argument("Invalid input parameters");

    // Borrowed pointer; the cache keeps ownership.
    ID3D11Texture2D* source = GetInterAreaSrcTexture(filename, pRenderer);
    ScaleImageWithComputeShader(outTexture, source, pRenderer, pRenderer->GetContext(), scaleW, scaleH);
}