import { Vector2, Vector3 } from 'three'
import { colorMapArray } from './ColorMap'

/**
 * Shaders to render 3D volumes using raycasting.
 * The applied techniques are based on similar implementations in the Visvis and Vispy projects.
 * This is not the only possible raycasting approach, hence the "1" in the original shader's naming.
 */

/**
 * Number of colors the fragment shader's `uColorMap` uniform array can hold.
 * GLSL requires uniform array sizes to be compile-time constants, so this value fixes
 * the array length (3 floats per RGB color) and also caps the `maxColors` uniform so the
 * shader never indexes past the end of the array.
 */
const MAX_COLORMAP_COLORS = 256

/**
 * three.js ShaderMaterial definition (uniforms + GLSL sources) that raycasts a 3D label
 * volume and shades the first sample along each ray whose value is a valid label id.
 *
 * Uniforms:
 *  - uSize: size of the volume in voxel units; used to map positions to texture coordinates.
 *  - uRenderThreshold / uClim: a sample must exceed
 *    `uRenderThreshold - 0.02 * (uClim.y - uClim.x)` to be considered.
 *  - uData: 3D texture holding the label volume (the red channel is sampled).
 *  - uOpacity: multiplier applied to the colormap alpha.
 *  - uLabelCount: upper bound on ray-marching iterations (see castIso in the fragment shader).
 *  - uColorMap / maxColors: flat RGB table (0-255 per channel) indexed by label id, and the
 *    number of colors it holds (capped at MAX_COLORMAP_COLORS).
 *  - uLabelColorSwitch: true = per-label colormap colors, false = every label rendered white.
 *  - uRenderStyle, uCmData: declared, but not read by the shaders in this file.
 */
export const LabelRenderShader = {
  uniforms: {
    uSize: { value: new Vector3(1, 1, 1) },
    uRenderStyle: { value: 0 },
    uRenderThreshold: { value: 0.5 },
    uClim: { value: new Vector2(1, 1) },
    uData: { value: null },
    uCmData: { value: null },
    uOpacity: { value: 0.5 },
    uLabelCount: { value: 0 },
    // Colormap as a flat float array: [r0, g0, b0, r1, g1, b1, ...]
    uColorMap: { value: new Float32Array(colorMapArray) },
    // Number of usable colors. Capped at the fixed uniform array capacity so that
    // label ids beyond the array are never looked up.
    maxColors: { value: Math.min(Math.floor(colorMapArray.length / 3), MAX_COLORMAP_COLORS) },
    uLabelColorSwitch: { value: true }
  },

  vertexShader: /* glsl */ `

		varying vec4 vNearpos;
		varying vec4 vFarpos;
		varying vec3 vPosition;

		void main() {
			// Prepare transforms to map to "camera view". See also:
			// https://threejs.org/docs/#api/renderers/webgl/WebGLProgram
			mat4 viewTransformf = modelViewMatrix;
			mat4 viewTransformi = inverse(modelViewMatrix);

			// Project local vertex coordinate to camera position. Then do a step
			// backward (in cam coords) to the near clipping plane, and project back. Do
			// the same for the far clipping plane. This gives us all the information we
			// need to calculate the ray and truncate it to the viewing cone.
			vec4 position4 = vec4(position, 1.0);
			vec4 posInCam = viewTransformf * position4;

			// Intersection of ray and near clipping plane (z = -1 in clip coords)
			posInCam.z = -posInCam.w;
			vNearpos = viewTransformi * posInCam;

			// Intersection of ray and far clipping plane (z = +1 in clip coords)
			posInCam.z = posInCam.w;
			vFarpos = viewTransformi * posInCam;

			// Set varyings and output pos
			vPosition = position;
			gl_Position = projectionMatrix * viewMatrix * modelMatrix * position4; // the built-in name "gl_Position" must not be renamed
		}`,

  /*
   * The fragment shader determines the color of the rendered label surfaces (e.g. bone).
   * Lighting is applied in addLighting(), which castIso() calls on the first label voxel
   * hit by the ray.
   * Most of the code is copied from VolumeShader.jsx.
   * While editing, I recommend keeping the shaders in separate .glsl files. That makes it
   * easier to follow the code and to find variable declarations.
   */
  fragmentShader: /* glsl */ `

    precision highp float;
    precision mediump sampler3D;

    uniform vec3 uSize;
    uniform int uRenderStyle;
    uniform float uRenderThreshold;
    uniform vec2 uClim;
    uniform float uOpacity;
    uniform int uLabelCount;

    uniform sampler3D uData;
    uniform sampler2D uCmData;

    varying vec3 vPosition;
    varying vec4 vNearpos;
    varying vec4 vFarpos;

    uniform bool uLabelColorSwitch;

    // Flat RGB colormap, 3 floats per color. In GLSL the size of a uniform array must be a
    // compile-time constant, so the capacity is fixed here. maxColors says how many entries
    // were actually uploaded (the JS side caps it at this capacity).
    const int COLORMAP_SIZE = ${MAX_COLORMAP_COLORS};
    uniform float uColorMap[${MAX_COLORMAP_COLORS * 3}];
    uniform int maxColors;

    const float shininess = 1.0;
    const float relativeStepSize = 1.0;

    float sample1(vec3 texCoords);
    void castIso(vec3 startLoc, vec3 step, int nSteps, vec3 viewRay);
    vec4 applyColorMap(float val, int label);
    vec4 addLighting(float val, vec3 loc, vec3 step, vec3 viewRay, int label);

    // Shade a surface voxel of the given label with a single light along the view ray.
    // The normal comes from the central-difference gradient of the volume at loc.
    vec4 addLighting(float val, vec3 loc, vec3 step, vec3 viewRay, int label)
    {
        // View direction
        vec3 V = normalize(viewRay);

        // Central differences per axis (value at -step minus value at +step)
        vec3 n = vec3(
            sample1(loc + vec3(-step.x, 0.0, 0.0)) - sample1(loc + vec3(step.x, 0.0, 0.0)),
            sample1(loc + vec3(0.0, -step.y, 0.0)) - sample1(loc + vec3(0.0, step.y, 0.0)),
            sample1(loc + vec3(0.0, 0.0, -step.z)) - sample1(loc + vec3(0.0, 0.0, step.z))
        );

        // When all neighbours are equal the gradient is zero, and normalize(vec3(0.0)) is
        // undefined (NaN colors). Fall back to a normal facing the viewer in that case.
        float gm = length(n); // gradient magnitude
        n = gm > 0.0 ? n / gm : V;

        // Flip normal so it points towards the viewer
        float nSelect = float(dot(n, V) > 0.0);
        n = (2.0 * nSelect - 1.0) * n; // == nSelect * n - (1.0 - nSelect) * n

        // Init colors. These three values dominate how the surface looks in 3D;
        // shininess (the specular exponent) also has some influence.
        vec4 ambientColor = vec4(0.2, 0.2, 0.2, 0.0);
        vec4 diffuseColor = vec4(0.0, 0.0, 0.0, 0.0);
        vec4 specularColor = vec4(0.05, 0.05, 0.05, 0.0);

        // note: could allow multiple lights
        for (int i = 0; i < 1; i++)
        {
            // Get light direction (make sure to prevent zero division)
            vec3 l = normalize(viewRay); // lightDirs[i];
            float lightEnabled = float(length(l) > 0.0);
            l = normalize(l + (1.0 - lightEnabled));

            // Calculate lighting properties
            float lambertTerm = clamp(dot(n, l), 0.0, 1.0);
            vec3 h = normalize(l + V); // Halfway vector
            float specularTerm = pow(max(dot(h, n), 0.0), shininess);

            // Calculate mask
            float mask1 = lightEnabled;

            // Calculate colors
            ambientColor += mask1 * ambientColor * 0.0001; // * gl_LightSource[i].ambient;
            diffuseColor += mask1 * lambertTerm * 0.6;
            specularColor += mask1 * specularTerm * specularColor * 0.000001;
        }

        // Combine the components into the final color
        vec4 color = applyColorMap(val, label);
        vec4 finalColor = color * (ambientColor + diffuseColor) + specularColor;
        finalColor.a = color.a * uOpacity;
        return finalColor;
    }

    // Map a label id to RGBA. The color is keyed on the label hit by the ray, not on
    // neighbouring sample values, so each surface keeps its own label color.
    // val is unused and kept only so the signature stays stable.
    vec4 applyColorMap(float val, int label) {
        // Alpha 30.0 (original value: 7.5) is scaled by uOpacity in addLighting.
        // NOTE(review): it exceeds 1.0 and presumably relies on output clamping; confirm.
        if (!uLabelColorSwitch) {
            // Monochrome mode: black for label 0, white for every other label
            return label == 0 ? vec4(0.0, 0.0, 0.0, 30.0) : vec4(1.0, 1.0, 1.0, 30.0);
        }
        // Clamp keeps the dynamic index inside the uniform array in all cases
        int index = clamp(label, 0, COLORMAP_SIZE - 1) * 3;
        return vec4(uColorMap[index] / 255.0, uColorMap[index + 1] / 255.0, uColorMap[index + 2] / 255.0, 30.0);
    }

    // March from startLoc along step and shade the first sample that is an exact label id
    // with a colormap entry. If nothing is hit, gl_FragColor stays transparent.
    void castIso(vec3 startLoc, vec3 step, int nSteps, vec3 viewRay) {
        gl_FragColor = vec4(0.0); // init transparent
        vec3 dstep = 1.5 / uSize; // step to sample derivative
        vec3 loc = startLoc;

        float lowThreshold = uRenderThreshold - 0.02 * (uClim[1] - uClim[0]);

        // Colormap entries 0..maxColors-1 exist (index = label * 3), so valid label ids
        // lie in [1, labelLimit). Label 0 is never shaded.
        float labelLimit = float(min(maxColors, COLORMAP_SIZE));

        // The march stops after nSteps samples and is additionally capped at uLabelCount
        // iterations. NOTE(review): capping ray depth by the label count looks unintended;
        // confirm what callers set uLabelCount to.
        int stepLimit = min(uLabelCount, nSteps);

        for (int iter = 0; iter < stepLimit; iter++) {
            // Sample from the 3D texture
            float val = sample1(loc);

            // Only exact integer ids inside the colormap range are labels; non-integer
            // samples are skipped. The range test comes before int() so that no
            // out-of-range float is ever converted.
            if (val > lowThreshold && val >= 1.0 && val < labelLimit && val == floor(val)) {
                gl_FragColor = addLighting(val, loc, dstep, viewRay, int(val));
                return;
            }

            // Advance location deeper into the volume
            loc += step;
        }
    }

    float sample1(vec3 texCoords) {
        /* Sample float value from a 3D texture. Assumes intensity data. */
        return texture(uData, texCoords.xyz).r;
    }

    void main() {
        // Normalize clipping plane info
        vec3 farPos = vFarpos.xyz / vFarpos.w;
        vec3 nearPos = vNearpos.xyz / vNearpos.w;

        // Calculate unit vector pointing in the view direction through this fragment.
        vec3 viewRay = normalize(nearPos.xyz - farPos.xyz);

        // Compute the (negative) distance to the front surface or near clipping plane.
        // vPosition is the back face of the cuboid, so the initial distance calculated in the dot
        // product below is the distance from near clip plane to the back of the cuboid
        float distance = dot(nearPos - vPosition, viewRay);
        distance = max(distance, min((-0.5 - vPosition.x) / viewRay.x,
                                     (uSize.x - 0.5 - vPosition.x) / viewRay.x));
        distance = max(distance, min((-0.5 - vPosition.y) / viewRay.y,
                                     (uSize.y - 0.5 - vPosition.y) / viewRay.y));
        distance = max(distance, min((-0.5 - vPosition.z) / viewRay.z,
                                     (uSize.z - 0.5 - vPosition.z) / viewRay.z));

        // Now we have the starting position on the front surface
        vec3 front = vPosition + viewRay * distance;

        // Decide how many steps to take
        int nSteps = int(-distance / relativeStepSize + 0.5);
        if (nSteps < 1)
            discard;

        // Get starting location and step vector in texture coordinates
        vec3 step = ((vPosition - front) / uSize) / float(nSteps);
        vec3 startLoc = front / uSize;

        // For testing: show the number of steps. This helps to establish
        // whether the rays are correctly oriented
        // gl_FragColor = vec4(0.0, float(nSteps) / 1.0 / uSize.x, 1.0, 1.0);
        // return;

        castIso(startLoc, step, nSteps, viewRay);

        if (gl_FragColor.a < 0.05) // the built-in name "gl_FragColor" must not be renamed
            discard;
    }`
}
