////////////////////////////////////////////////////////////////////////////////
// Filename: shader.hlsi
////////////////////////////////////////////////////////////////////////////////

//////////////
// TEXTURES //
//////////////

// Shadow-map array bound to t0. The shaders below index it with
// float3(uv, light_index), i.e. one array slice per light, and read either
// .r (LightShaderSM) or .rg (LightShaderVSM, as two depth moments).
Texture2DArray shadowTexture : register(t0);

///////////////////
// SAMPLE STATES //
///////////////////

// Sampler used for every shadowTexture lookup in this file (slot s0).
// Its filter and address modes are not set here.
SamplerState shadowSampler : register(s0);

/////////////
// GLOBALS //
/////////////

// Object transform constants (slot b0). vertexShader uses worldMatrix only to
// produce PixelInput.worldPosition.
cbuffer MatrixBuffer : register(b0) {
    matrix worldMatrix;               // Transform into world space
};

//////////////////////
// CONSTANT BUFFERS //
//////////////////////

// Description of a view. Used directly for the camera (CameraBuffer) and as
// the base of LightType, so every light carries the same fields.
// NOTE(review): the member order presumably mirrors a CPU-side struct - do not
// reorder without checking the application code.
struct CameraType {
    matrix view;                      // View matrix (not referenced by the shaders in this file)
    matrix projection;                // Applied directly to the input vertex position by vertexShader

	float4 position;                  // Eye position; .xyz is used as the light position for lights
	float4 target;                    // Look-at target
    float4 direction;                // Facing direction; .xyz is used as the spot direction for lights
	float4 rotation;                  // Rotation

	float2 plane;                     // Clipping plane (near/far)
	float  FOV;                       // Field of view (radians); also sizes the spot cone for lights
    float  alpha;                     // View transparency (for fades); multiplied into the output alpha
};

// The active camera (slot b1). Read by vertexShader (projection),
// SurfaceShader (position) and pixelShader (alpha).
cbuffer CameraBuffer : register(b1)
{
    CameraType camera;
};

// A light is a CameraType (position, direction, FOV, projection) plus the
// shading parameters below.
struct LightType : CameraType {
    float4  colour;                   // Light colour/intensity; only .rgb is read by LightShader
    float2  linearNearFar;            // (near, far) range used to rescale light distance into [0, 1]
    float2  distanceFalloff;          // Distance falloff range; attenuation is skipped when .y == 0
    float2  angleFalloff;             // Angular falloff as fractions of FOV/2; skipped when .y == 0
};

// Number of entries in the light array below.
#define NUM_LIGHTS 3

// All lights (slot b2), indexed by the light_index argument of the
// LightShader* functions.
cbuffer LightBuffer : register(b2) {
    LightType light[NUM_LIGHTS];
};

// Material constants (slot b3).
cbuffer MaterialBuffer : register(b3)
{
    float4 ambientColour;   // .rgb is added once to the lit result in pixelShader
    float4 specularColour;  // .rgb tints the specular term in SurfaceShader
    float  specularPower;   // Specular exponent; 0 disables the specular term
};

// Shadow filter kernel configuration.
//   PCF_RANGE    : kernel radius; sampleShadowPCF takes (2 * PCF_RANGE + 1)^2 taps.
//   PCF_WEIGHTED : 0 = every tap weighted equally,
//                  1 = taps weighted by SAT::TABLE (separable tent weights).
#define PCF_RANGE 1
#define PCF_WEIGHTED 0
#if PCF_WEIGHTED
namespace SAT {
// Number of taps along one side of the kernel.
#define SAT_SIZE (PCF_RANGE * 2 + 1)

// Sn(Y) expands to the 2n+1 one-dimensional weights
//   (Y-n)/Y, ..., (Y-1)/Y, 1, (Y-1)/Y, ..., (Y-n)/Y
// i.e. a tent that peaks at the centre tap.
#define S0(Y) (Y) / (Y)
#define S1(Y) (Y-1.f) / (Y), S0(Y), (Y-1.f) / (Y)
#define S2(Y) (Y-2.f) / (Y), S1(Y), (Y-2.f) / (Y)
#define S3(Y) (Y-3.f) / (Y), S2(Y), (Y-3.f) / (Y)
#define S4(Y) (Y-4.f) / (Y), S3(Y), (Y-4.f) / (Y)

// Two-step expansion: an argument that is an operand of ## is pasted without
// being macro-expanded, so pasting PCF_RANGE directly would produce the
// undefined name SPCF_RANGE. DEFINE_S expands its argument first and only
// then hands the resulting number to DEFINE_S_IMPL for pasting.
#define DEFINE_S_IMPL(range) static const float S[SAT_SIZE] = { S##range(range+1) };
#define DEFINE_S(range) DEFINE_S_IMPL(range)

// One-dimensional weights for the configured radius.
DEFINE_S(PCF_RANGE)

// ROWn(Y) expands to the 2n+1 products S[x] * S[Y] for one value of Y.
#define ROW0(Y) S[0] * S[Y] // 0
#define ROW1(Y) ROW0(Y), S[1] * S[Y], S[2] * S[Y] // 1
#define ROW2(Y) ROW1(Y), S[3] * S[Y], S[4] * S[Y] // 2
#define ROW3(Y) ROW2(Y), S[5] * S[Y], S[6] * S[Y] // 3
#define ROW4(Y) ROW3(Y), S[7] * S[Y], S[8] * S[Y] // 4

// Two-dimensional weights, consumed sequentially by sampleShadowPCF.
// Exactly one initialiser line must be active and it has to match PCF_RANGE;
// the active line below is the PCF_RANGE == 1 variant.
static const float TABLE[SAT_SIZE * SAT_SIZE] = {
//    ROW0(0) // 0
    ROW1(0), ROW1(1), ROW1(2) // 1
//    ROW2(0), ROW2(1), ROW2(2), ROW2(3), ROW2(4) // 2
//    ROW3(0), ROW3(1), ROW3(2), ROW3(3), ROW3(4), ROW3(5), ROW3(6) // 3
//    ROW4(0), ROW4(1), ROW4(2), ROW4(3), ROW4(4), ROW4(5), ROW4(6), ROW4(7), ROW4(8) // 4
};

};
#endif

//////////////
// TYPEDEFS //
//////////////
// Per-vertex attributes consumed by vertexShader.
struct VertexInput
{
    float4 position : POSITION; // Object space position (vertexShader overwrites .w with 1)
    float4 colour : COLOR;      // Object colour
    float3 normal : NORMAL;     // Object space normal
};

// Interpolants written by vertexShader and read by pixelShader.
// lightPosition is sized by NUM_LIGHTS so it stays in step with LightBuffer;
// as an array it occupies consecutive TEXCOORD slots starting at TEXCOORD1.
struct PixelInput
{
    float4 position : SV_POSITION;                 // Vertex position after camera.projection
    float4 colour   : COLOR;                       // Vertex colour, passed through unchanged
    float3 worldPosition : TEXCOORD0;              // Vertex position in world space
    float4 lightPosition[NUM_LIGHTS] : TEXCOORD1;  // Vertex position after each light's projection (before the w divide)
    float3 normal   : NORMAL;                      // Normalised input normal
};

//--------------------------------------------------------------------------------------
// Misc Utilities
//--------------------------------------------------------------------------------------

// Maps value from the interval [min, max] onto [0, 1]; values outside the
// interval are clamped to the nearest end.
float scale(float min, float max, float value) {
    const float span = max - min;
    const float t = (value - min) / span;
    return clamp(t, 0.f, 1.f);
}

// Range overload of scale(): range.x is the lower bound and range.y the upper
// bound. The result is clamped to [0, 1].
float scale(const float2 range, float value) {
    const float lower = range.x;
    const float upper = range.y;
    const float t = (value - lower) / (upper - lower);
    return clamp(t, 0.f, 1.f);
}

// Returns the constant pair (0.5, 0).
// NOTE(review): nothing in the visible source calls this; LightShaderSM
// declares its own local named `bias` instead.
float2 bias() {
    const float2 kBias = float2(0.5, 0);
    return kBias;
}

// One-tailed Chebyshev bound used for variance shadow mapping.
//   moments : moments.x is treated as the mean E[x] and moments.y as E[x^2]
//             (variance = E[x^2] - E[x]^2).
//   mean    : despite its name, this is the value tested against the
//             distribution; LightShaderVSM passes the rescaled light distance.
// Returns 1 when `mean` does not exceed moments.x, otherwise the larger of
// that hard test (0) and variance / (variance + (mean - E[x])^2).
float ChebyshevUpperBound(float2 moments, float mean)
{
    // Floor on the variance so the quotient below never divides by zero.
    const float kMinVariance = 0.000005;

    const float delta  = mean - moments.x;
    const float sigma2 = max(moments.y - (moments.x * moments.x), kMinVariance);

    // Hard comparison, as in a standard shadow map (bool converted to 0/1).
    const float hardTest = (mean <= moments.x);

    // Probabilistic upper bound.
    const float softBound = sigma2 / (sigma2 + delta * delta);

    return max(hardTest, softBound);
}

// Light bleeding reduction: remaps p from [amount, 1] onto [0, 1], so every
// value at or below `amount` becomes fully shadowed.
float LBR(float p)
{
    // Any falloff works here if the result need not remain an upper bound;
    // tune per scene. An earlier setting was 0.18f.
    const float amount = 0.5f;

    // Same arithmetic as scale(amount, 1, p).
    const float t = (p - amount) / (1 - amount);
    return clamp(t, 0.f, 1.f);
}

//--------------------------------------------------------------------------------------
// Light attenuation and shadow-map sampling helpers (VSM using hardware filtering)
//--------------------------------------------------------------------------------------

// Computes the unshadowed light colour arriving at a surface point.
//   light_index      : index into light[].
//   surfacePosition  : surface point, in the same space as light[].position.
//   surfaceNormal    : not used by this function.
//   distanceToLight  : (out) distance from the surface point to the light.
//   directionToLight : (out) unit vector from the surface point to the light.
// Returns light colour * distance attenuation * angular attenuation.
// Each attenuation is applied only when the .y of its falloff range is non-zero.
float3 LightShader(int light_index, float3 surfacePosition, float3 surfaceNormal,
                    out float distanceToLight, out float3 directionToLight) {

    directionToLight = light[light_index].position.xyz - surfacePosition;
    distanceToLight = length(directionToLight);
    directionToLight /= distanceToLight;

    float distanceAttenuation = 1.f, angleAttenuation = 1.f;

    [branch] if(light[light_index].distanceFalloff.y)
        distanceAttenuation = 1.f - scale(light[light_index].distanceFalloff, distanceToLight);

    [branch] if(light[light_index].angleFalloff.y) {
        const float angle = dot(directionToLight, light[light_index].direction.xyz);

        // Cosines of the inner/outer cone half-angles for this light only.
        // Computed per call for the selected light, so it works for any
        // NUM_LIGHTS and does not evaluate the other lights' cones.
        const float2 coneCosines = cos(0.5f * light[light_index].FOV * light[light_index].angleFalloff);

        angleAttenuation = 1.f - scale(coneCosines, angle);
    }

    const float attenuation = saturate(distanceAttenuation * angleAttenuation);
    return light[light_index].colour.rgb * attenuation;
}

// Filters the .rg channels of shadow-map slice `light_index` around
// position.xy over a (2 * PCF_RANGE + 1)^2 neighbourhood of integer sample
// offsets. position.z is not used.
float2 sampleShadowPCF(int light_index, const in float3 position){
    const float3 uvSlice = float3(position.xy, light_index);
    float2 sum = 0.f;

#if PCF_WEIGHTED
    // Weighted: taps are scaled by SAT::TABLE, read in visiting order.
    int tap = 0;
    [unroll] for (int dx = -PCF_RANGE; dx <= PCF_RANGE; dx++) {
        [unroll] for (int dy = -PCF_RANGE; dy <= PCF_RANGE; dy++) {
            sum += SAT::TABLE[tap++] * shadowTexture.Sample(shadowSampler, uvSlice, int2(dx, dy)).rg;
        }
    }

    const float weightTotal = (PCF_RANGE + 1) * (PCF_RANGE + 1);
    return sum / weightTotal;
#else
    // Unweighted: plain average of all taps.
    [unroll] for (int dx = -PCF_RANGE; dx <= PCF_RANGE; dx++) {
        [unroll] for (int dy = -PCF_RANGE; dy <= PCF_RANGE; dy++) {
            sum += shadowTexture.Sample(shadowSampler, uvSlice, int2(dx, dy)).rg;
        }
    }

    const float tapCount = (PCF_RANGE * 2 + 1) * (PCF_RANGE * 2 + 1);
    return sum / tapCount;
#endif
}

// Single lookup into shadow-map slice `light_index` at `position`;
// returns the .rg channels.
float2 sampleShadow(int light_index, const in float2 position){
    // Alternative kept for reference: shadowTexture.Load(position.x, position.y)
    const float3 uvSlice = float3(position, light_index);
    return shadowTexture.Sample(shadowSampler, uvSlice).rg;
}

//--------------------------------------------------------------------------------------
// Point sampling (standard shadow mapping)
//--------------------------------------------------------------------------------------
// Standard shadow mapping with one depth comparison.
//   distanceToLight  : (out) light distance rescaled by linearNearFar, minus
//                      a small depth bias.
//   directionToLight : (out) unit vector from the surface to the light.
// Returns rgb = light colour (zero when shadowed), a = 1 when shadowed, else 0.
float4 LightShaderSM(int light_index, float3 surfacePosition, float3 surfaceNormal,
                         float2 shadowTextureCoord, out float distanceToLight, out float3 directionToLight)
{
    // Offset subtracted from the receiver depth before the comparison.
    const float depthBias = 0.006f;

    // Unshadowed lighting; also fills both out parameters.
    const float3 lit = LightShader(light_index, surfacePosition, surfaceNormal, distanceToLight, directionToLight);
    distanceToLight = scale(light[light_index].linearNearFar, distanceToLight) - depthBias;

    // Lit when the stored depth is at or beyond the receiver depth.
    const float storedDepth = sampleShadow(light_index, shadowTextureCoord).r;
    const float visible = (storedDepth >= distanceToLight);

    return float4(lit * visible, 1.f - visible);
}

// Variance shadow mapping.
//   shadowTextureCoord : xy = shadow-map uv, z = projected depth; all three
//                        must lie in [0, 1] for the light to contribute.
//   distanceToLight    : (out) light distance rescaled by linearNearFar.
//   directionToLight   : (out) unit vector from the surface to the light.
// Returns rgb = light colour scaled by visibility, a = 1 - visibility.
// Outside the shadow map the result is (0, 0, 0, 0).
float4 LightShaderVSM(int light_index, float3 surfacePosition, float3 surfaceNormal,
                          float3 shadowTextureCoord, out float distanceToLight, out float3 directionToLight)
{
    // Unshadowed lighting; also fills both out parameters.
    const float3 lit = LightShader(light_index, surfacePosition, surfaceNormal, distanceToLight, directionToLight);
    distanceToLight = scale(light[light_index].linearNearFar, distanceToLight);

    // A component that changes under saturate() lies outside [0, 1].
    const bool outsideMap = any(saturate(shadowTextureCoord) != shadowTextureCoord);
    if (outsideMap)
        return float4(0.f, 0.f, 0.f, 0.0f);

    // Filtered moments -> Chebyshev visibility -> light-bleeding reduction.
    const float2 moments = sampleShadowPCF(light_index, shadowTextureCoord);
    const float visibility = LBR(ChebyshevUpperBound(moments, distanceToLight));

    return float4(lit * visibility, 1.f - visibility);
}


// Greyscale diffuse term: saturate(N . L) replicated to rgb, alpha 1.
float4 BasicShader(const float3 normal, const float3 directionToLight) {
    const float lambert = saturate(dot(normal, directionToLight));
    return float4(lambert.xxx, 1.f);
}

// Diffuse term tinted by the incoming light: saturate(N . L) * illumination,
// alpha 1.
float4 BasicShader(const float3 normal, const float3 directionToLight, const float3 illumination) {
    const float lambert = saturate(dot(normal, directionToLight));
    const float3 lit = lambert * illumination;
    return float4(lit, 1.f);
}

// Blinn-Phong shading of one light's contribution.
//   position         : surface point, in the same space as camera.position.
//   normal           : surface normal (expected to be unit length).
//   directionToLight : unit vector from the surface to the light.
//   illumination     : light colour already attenuated/shadowed.
// Returns rgb = illumination * (diffuse + specular), a = 1. Ambient is not
// added here; pixelShader adds it once after all lights are summed.
float4 SurfaceShader(const float3 position,
                     const float3 normal,
                     const float3 directionToLight,
                     const float3 illumination) {

    // The half vector is only the bisector of L and V when both are unit
    // length, so the view vector must be normalised. Left un-normalised, its
    // magnitude (the camera distance) dominates the sum and the highlight
    // follows the camera instead of the light.
    const float3 viewDirection = normalize(camera.position.xyz - position);

    const float diffuseAmount = saturate(dot(normal, directionToLight));

    const float3 halfVector = normalize(directionToLight + viewDirection);

    // A zero exponent switches specular off (pow(x, 0) would otherwise be 1).
    const float specularAmount = specularPower == 0.f
        ? 0.f
        : saturate(pow(max(0, dot(normal, halfVector)), specularPower));

    // Combine diffuse and specular, scaled by the incoming light.
    return float4(illumination * (diffuseAmount + saturate(specularColour.rgb * specularAmount)), 1.f);
}

////////////////////////////////////////////////////////////////////////////////
// Vertex Shader
////////////////////////////////////////////////////////////////////////////////

// Vertex stage: projects the vertex for the camera and for every light, and
// forwards world position, normal and colour to the pixel stage.
PixelInput vertexShader(VertexInput input) {

    PixelInput output;

    // Treat the input as a point regardless of the w supplied by the mesh.
    input.position.w = 1.0f;

    // NOTE(review): worldMatrix is applied only to worldPosition below; the
    // camera and light projections receive the untransformed input position
    // (a world-space variant was present but disabled). Confirm this is
    // intended for objects whose worldMatrix is not identity.
    output.position = mul(input.position, camera.projection);

    output.worldPosition = mul(input.position, worldMatrix).xyz;

    // NOTE(review): the normal is normalised but not rotated by worldMatrix
    // (that variant was disabled too).
    output.normal = normalize(input.normal);

    // Position of the vertex as seen by each light, used for shadow lookups.
    // Iterates over NUM_LIGHTS instead of spelling out indices 0..2.
    [unroll] for (int i = 0; i < NUM_LIGHTS; ++i) {
        output.lightPosition[i] = mul(input.position, light[i].projection);
    }

    output.colour = input.colour;

    return output;
}

// Pixel stage: sums the shadowed contribution of every light.
//   input.colour.a != 0 : regular surface. Output rgb = ambient + sum over
//                         lights of albedo * Blinn-Phong; alpha = vertex alpha.
//   input.colour.a == 0 : the surface only receives shadow. Output rgb =
//                         accumulated vertex colour; alpha grows with the
//                         shadow amount returned by LightShaderVSM.
// Both alphas are multiplied by camera.alpha.
float4 pixelShader(PixelInput input) : SV_TARGET{
    float4 colour = 0;// float4(ambientColour.rgb* ambientColour.rgb* ambientColour.rgb, 1.0);

    for(int l = 0; l < NUM_LIGHTS; l++)
    {
        float distanceToLight;
        float3 directionToLight;

        // Perspective divide, then map x,y from [-1, 1] to [0, 1] uv
        // (y flipped); z/w is the projected depth.
        const float4 lightClip = input.lightPosition[l];
        const float3 texCoord = float3((lightClip.xy / lightClip.w) * float2(0.5, -0.5) + 0.5,
                                        lightClip.z / lightClip.w);
        float4 illumination = LightShaderVSM(l, input.worldPosition, input.normal, texCoord, distanceToLight, directionToLight);
        if (input.colour.a) {
            float4 surface = SurfaceShader(input.worldPosition, input.normal, directionToLight, illumination.rgb);
            // Each light contributes albedo * surface exactly once. The
            // running total must not be fed back into the product, otherwise
            // earlier lights get multiplied by albedo again on every later
            // iteration.
            colour += float4(input.colour.rgb * surface.rgb, input.colour.a);
        }
        else {
            colour += float4(input.colour.rgb, illumination.a * 0.5 * (input.colour.a * 0.5 + 0.5));
        }
    }

    if (input.colour.a) {
        return float4(saturate(ambientColour.rgb + colour.rgb), input.colour.a * camera.alpha);
    } else {
        return float4(saturate(colour.rgb), colour.a * camera.alpha);
    }

    // Disabled earlier variants, kept for reference.

    // Simple shader
    //colour += BasicShader(input.normal, light.direction, illumination);

    //// LIGHT ONLY

    //// Invert the light direction for calculations.
    //float3 lightDir = -light.direction;
    //lightDir = normalize(lightDir);

    //// Calculate the amount of light on this pixel.
    //float lightIntensity = saturate(dot(input.normal, lightDir));

    //// Multiply the texture pixel and the final diffuse color to get the final pixel color result.
    //return saturate(ambientColour * input.colour + diffuseColour * lightIntensity * input.colour);

}