import { Vector2, Vector3 } from 'three'

/**
 * Shaders to render 3D volumes using raycasting.
 * The applied techniques are based on similar implementations in the Visvis and Vispy projects.
 * This is not the only possible approach, which is why the shader was originally "marked 1"; note that the export below carries no such suffix.
 */

// Uniform defaults plus GLSL vertex/fragment sources for raycasting a 3D texture.
// The fragment shader marches a ray through the volume and either keeps the
// maximum sample (uRenderStyle == 0) or shades the first sample that falls
// inside the threshold window (uRenderStyle == 1).
export const VolumeRenderShader = {
  uniforms: {
    // Volume extent; the fragment shader divides positions by uSize to obtain
    // texture coordinates and treats the box as spanning -0.5 .. uSize - 0.5.
    uSize: { value: new Vector3(1, 1, 1) },
    // 0 selects castMip, 1 selects castIso; any other value renders nothing.
    uRenderStyle: { value: 0 },
    // Iso mode shades the first sample strictly between these two thresholds.
    uRenderThreshold: { value: 0.5 },
    uMaxRenderThreshold: { value: 100.0 },
    // Colormap limits: applyColorMap rescales [uClim.x, uClim.y] to [0, 1].
    // NOTE(review): the default (1, 1) is a zero-width range, so applyColorMap
    // divides by zero until the caller sets real limits.
    uClim: { value: new Vector2(1, 1) },
    // 3D texture holding the volume; only the red channel is read.
    uData: { value: null },
    // 2D colormap texture, sampled along x at y = 0.5.
    uCmData: { value: null },
    // Multiplies the alpha of lit iso-surface fragments (not used by MIP).
    uOpacity: { value: 0.5 }
  },

  vertexShader: /* glsl */ `

    varying vec4 vNearpos;
    varying vec4 vFarpos;
    varying vec3 vPosition;

    void main() {
      // Prepare transforms to map to "camera view". See also:
      // https://threejs.org/docs/#api/renderers/webgl/WebGLProgram
      mat4 viewTransformf = modelViewMatrix;
      mat4 viewTransformi = inverse(modelViewMatrix);

      // Project local vertex coordinate to camera position. Then do a step
      // backward (in cam coords) to the near clipping plane, and project back. Do
      // the same for the far clipping plane. This gives us all the information we
      // need to calculate the ray and truncate it to the viewing cone.
      vec4 position4 = vec4(position, 1.0);
      vec4 posInCam = viewTransformf * position4;

      // Intersection of ray and near clipping plane (z = -1 in clip coords)
      posInCam.z = -posInCam.w;
      vNearpos = viewTransformi * posInCam;

      // Intersection of ray and far clipping plane (z = +1 in clip coords)
      posInCam.z = posInCam.w;
      vFarpos = viewTransformi * posInCam;

      // Set varyings and output pos
      vPosition = position;
      gl_Position = projectionMatrix * viewMatrix * modelMatrix * position4; // variable name "gl_Position" should be not changed
    }`,

  fragmentShader: /* glsl */ `

    precision highp float;
    precision mediump sampler3D;

    uniform vec3 uSize;
    uniform int uRenderStyle;
    uniform float uRenderThreshold;
    uniform float uMaxRenderThreshold;
    uniform vec2 uClim;
    uniform float uOpacity;

    uniform sampler3D uData;
    uniform sampler2D uCmData;

    varying vec3 vPosition;
    varying vec4 vNearpos;
    varying vec4 vFarpos;

    // The maximum distance through our rendering volume is sqrt(3).
    const int MAX_STEPS = 887;  // 887 for 512^3, 1774 for 1024^3
    const int REFINEMENT_STEPS = 4;
    const float relativeStepSize = 1.0;

    // Lighting parameters.
    // NOTE(review): none of these four constants is applied at the moment.
    // addLighting used to declare zero-initialised locals with the same names,
    // which shadowed them, so the ambient and specular terms were always zero.
    // The shader keeps that established look (colormap * Lambert term); wiring
    // these in would change every iso-surface and should be a deliberate choice.
    const vec4 ambientColor = vec4(0.2, 0.4, 0.2, 1.0);
    const vec4 diffuseColor = vec4(0.8, 0.2, 0.2, 1.0);
    const vec4 specularColor = vec4(1.0, 1.0, 1.0, 1.0);
    const float shininess = 40.0;

    void castMip(vec3 startLoc, vec3 step, int nSteps, vec3 viewRay);
    void castIso(vec3 startLoc, vec3 step, int nSteps, vec3 viewRay);

    float sample1(vec3 texCoords);
    vec4 applyColorMap(float val);
    vec4 addLighting(float val, vec3 loc, vec3 step, vec3 viewRay);

    // Shade the sample at loc: estimate a surface normal from the local
    // gradient, apply Lambert shading with the light placed along the view
    // ray, and return the colormapped colour with alpha scaled by uOpacity.
    //   val     - sample value at loc (raised to the neighbourhood maximum below)
    //   loc     - sample position in texture coordinates
    //   step    - per-axis offset used for the central differences
    //   viewRay - direction of the ray through this fragment
    vec4 addLighting(float val, vec3 loc, vec3 step, vec3 viewRay)
    {
      // View direction
      vec3 V = normalize(viewRay);

      // Calculate normal vector from gradient (central differences per axis).
      // The colour lookup uses the largest value seen in the neighbourhood.
      vec3 n;
      float val1, val2;
      val1 = sample1(loc + vec3(-step[0], 0.0, 0.0));
      val2 = sample1(loc + vec3(+step[0], 0.0, 0.0));
      n[0] = val1 - val2;
      val = max(max(val1, val2), val);
      val1 = sample1(loc + vec3(0.0, -step[1], 0.0));
      val2 = sample1(loc + vec3(0.0, +step[1], 0.0));
      n[1] = val1 - val2;
      val = max(max(val1, val2), val);
      val1 = sample1(loc + vec3(0.0, 0.0, -step[2]));
      val2 = sample1(loc + vec3(0.0, 0.0, +step[2]));
      n[2] = val1 - val2;
      val = max(max(val1, val2), val);

      // A homogeneous neighbourhood has a zero gradient, for which normalize()
      // is undefined. Fall back to a normal facing the viewer in that case.
      float gm = length(n); // gradient magnitude
      n = gm > 0.0 ? normalize(n) : V;

      // Flip normal so it points towards viewer
      float nSelect = float(dot(n, V) > 0.0);
      n = (2.0 * nSelect - 1.0) * n;  // == nSelect * n - (1.0 - nSelect) * n;

      // Accumulated diffuse contribution over all lights
      vec4 diffuseLight = vec4(0.0);

      // note: could allow multiple lights
      for (int i = 0; i < 1; i++)
      {
        // Get light direction (make sure to prevent zero division)
        vec3 l = normalize(viewRay);  // lightDirs[i];
        float lightEnabled = float(length(l) > 0.0);
        l = normalize(l + (1.0 - lightEnabled));

        // Lambert term, masked out for disabled lights
        float lambertTerm = clamp(dot(n, l), 0.0, 1.0);
        diffuseLight += lightEnabled * lambertTerm;
      }

      // Calculate final color by composing the different components
      vec4 color = applyColorMap(val);
      vec4 finalColor = color * diffuseLight;
      finalColor.a = color.a * uOpacity;
      return finalColor;
    }

    // Rescale val from [uClim[0], uClim[1]] to [0, 1] and look it up in the colormap.
    vec4 applyColorMap(float val) {
      val = (val - uClim[0]) / (uClim[1] - uClim[0]);
      return texture2D(uCmData, vec2(val, 0.5));
    }

    // Iso-surface rendering: march along the ray and shade the first sample
    // lying strictly between uRenderThreshold and uMaxRenderThreshold.
    // Leaves gl_FragColor transparent when no such sample is found.
    void castIso(vec3 startLoc, vec3 step, int nSteps, vec3 viewRay) {

      gl_FragColor = vec4(0.0);  // init transparent
      vec3 dstep = 1.5 / uSize;  // step to sample derivative
      vec3 loc = startLoc;

      // Slightly widened window used for the coarse test, so that a surface
      // sitting between two coarse samples still triggers the refinement.
      float lowThreshold = uRenderThreshold - 0.02 * (uClim[1] - uClim[0]);
      float maxThreshold = uMaxRenderThreshold + 0.02 * (uClim[1] - uClim[0]);

      // Enter the raycasting loop. In WebGL 1 the loop index cannot be compared with
      // non-constant expression. So we use a hard-coded max, and an additional condition
      // inside the loop.
      for (int iter = 0; iter < MAX_STEPS; iter++) {
        if (iter >= nSteps)
          break;

        // Sample from the 3D texture
        float val = sample1(loc);

        if (val > lowThreshold && val < maxThreshold) {
          // Take the last interval in smaller steps
          vec3 iLoc = loc - 0.5 * step;
          vec3 iStep = step / float(REFINEMENT_STEPS);
          for (int i = 0; i < REFINEMENT_STEPS; i++) {
            val = sample1(iLoc);
            if (val > uRenderThreshold && val < uMaxRenderThreshold) {
              gl_FragColor = addLighting(val, iLoc, dstep, viewRay);
              return;
            }
            iLoc += iStep;
          }
        }

        // Advance location deeper into the volume
        loc += step;
      }
    }

    // Maximum intensity projection: colour the fragment by the largest sample
    // found along the ray.
    void castMip(vec3 startLoc, vec3 step, int nSteps, vec3 viewRay) {

      float maxVal = -1e6;
      int maxI = 100;
      vec3 loc = startLoc;

      // Enter the raycasting loop. In WebGL 1 the loop index cannot be compared with
      // non-constant expression. So we use a hard-coded max, and an additional condition
      // inside the loop.
      for (int iter = 0; iter < MAX_STEPS; iter++) {
        if (iter >= nSteps)
          break;
        // Sample from the 3D texture
        float val = sample1(loc);
        // Apply MIP operation
        if (val > maxVal) {
          maxVal = val;
          maxI = iter;
        }
        // Advance location deeper into the volume
        loc += step;
      }

      // Refine location around the maximum, gives crisper images
      vec3 iLoc = startLoc + step * (float(maxI) - 0.5);
      vec3 iStep = step / float(REFINEMENT_STEPS);
      for (int i = 0; i < REFINEMENT_STEPS; i++) {
        maxVal = max(maxVal, sample1(iLoc));
        iLoc += iStep;
      }

      // Resolve final color
      gl_FragColor = applyColorMap(maxVal);
    }

    float sample1(vec3 texCoords) {
      /* Sample float value from a 3D texture. Assumes intensity data. */
      return texture(uData, texCoords.xyz).r;
    }

    void main() {
      // Normalize clipping plane info
      vec3 farPos = vFarpos.xyz / vFarpos.w;
      vec3 nearPos = vNearpos.xyz / vNearpos.w;

      // Calculate unit vector pointing in the view direction through this fragment.
      vec3 viewRay = normalize(nearPos.xyz - farPos.xyz);

      // Compute the (negative) distance to the front surface or near clipping plane.
      // vPosition is the back face of the cuboid, so the initial distance calculated in the dot
      // product below is the distance from near clip plane to the back of the cuboid
      float distance = dot(nearPos - vPosition, viewRay);
      distance = max(distance, min((-0.5 - vPosition.x) / viewRay.x,
                                   (uSize.x - 0.5 - vPosition.x) / viewRay.x));
      distance = max(distance, min((-0.5 - vPosition.y) / viewRay.y,
                                   (uSize.y - 0.5 - vPosition.y) / viewRay.y));
      distance = max(distance, min((-0.5 - vPosition.z) / viewRay.z,
                                   (uSize.z - 0.5 - vPosition.z) / viewRay.z));

      // Now we have the starting position on the front surface
      vec3 front = vPosition + viewRay * distance;

      // Decide how many steps to take
      int nSteps = int(-distance / relativeStepSize + 0.5);
      if (nSteps < 1)
        discard;

      // Get starting location and step vector in texture coordinates
      vec3 step = ((vPosition - front) / uSize) / float(nSteps);
      vec3 startLoc = front / uSize;

      // For testing: show the number of steps. This helps to establish
      // whether the rays are correctly oriented
      // gl_FragColor = vec4(0.0, float(nSteps) / 1.0 / uSize.x, 1.0, 1.0);
      // return;

      // Start from a transparent fragment so that an unrecognised uRenderStyle
      // is discarded below instead of reading a value nothing has written.
      gl_FragColor = vec4(0.0);

      if (uRenderStyle == 0)
        castMip(startLoc, step, nSteps, viewRay);
      else if (uRenderStyle == 1)
        castIso(startLoc, step, nSteps, viewRay);

      if (gl_FragColor.a < 0.05) // the variable name "gl_FragColor" should not be changed
        discard;
    }`
}
