import npyjs from 'https://cdn.jsdelivr.net/npm/npyjs@0.6.0/+esm'

class Decompressor {
  /**
   * Creates a decompressor and registers the per-parameter decoding routines.
   *
   * `means` is decoded from a pair of 8-bit PNGs combined into 16-bit values,
   * `shN` from k-means codebook data in an `.npz` archive, and every other
   * registered parameter from a single 8-bit PNG.
   */
  constructor() {
    console.log('Decompressor instance created');
    this.decompressFnMap = {
      means: this._decompressPng16Bit.bind(this),
      scales: this._decompressPng.bind(this),
      quats: this._decompressPng.bind(this),
      opacities: this._decompressPng.bind(this),
      sh0: this._decompressPng.bind(this),
      shN: this._decompressKMeans.bind(this),
    };
  }

  /**
   * Looks up the decompression routine registered for `paramName`.
   *
   * Only the map's own keys are considered, so names such as `constructor`
   * read from a meta file never resolve to inherited `Object.prototype` members.
   *
   * @param {string} paramName - Parameter name, e.g. `means` or `shN`.
   * @returns {Function|undefined} The bound decompressor, or `undefined`
   *   (after logging a warning) when none is registered.
   */
  _getDecompressFn(paramName) {
    if (Object.hasOwn(this.decompressFnMap, paramName)) {
      return this.decompressFnMap[paramName];
    }
    console.warn(`No decompression function found for paramName: ${paramName}`);
    return undefined;
  }

  /**
   * Decompresses every parameter listed in `compressDir['meta.json']`.
   *
   * @param {Object<string, *>} compressDir - Map from file name to file contents.
   *   `meta.json` must already be a parsed object whose keys are parameter names
   *   and whose values are that parameter's metadata. Image entries are passed to
   *   `URL.createObjectURL` (presumably Blobs); `<param>.npz` entries are passed
   *   to `JSZip.loadAsync`.
   * @returns {Promise<Object<string, Array>>} Decoded parameters keyed by name.
   *   Parameters without a registered decompressor are skipped. When `means` is
   *   present, `inverseLogTransform` is applied to it.
   * @throws {Error} If `compressDir` has no `meta.json` entry.
   */
  async decompress(compressDir) {
    const meta = compressDir['meta.json'];
    if (meta == null) {
      throw new Error('compressDir is missing meta.json');
    }

    const splats = {};
    for (const [paramName, paramMeta] of Object.entries(meta)) {
      const decompressFn = this._getDecompressFn(paramName);
      if (!decompressFn) {
        // Already reported by _getDecompressFn; skip rather than call undefined.
        continue;
      }
      splats[paramName] = await decompressFn(compressDir, paramName, paramMeta);
    }

    // Param-specific postprocessing.
    if (splats.means !== undefined) {
      splats.means = this.inverseLogTransform(splats.means);
    }
    return splats;
  }

  /**
   * Decodes a parameter stored as one 8-bit PNG named `<paramName>.png`.
   *
   * Each channel value is mapped from [0, 255] to [0, 1] and then rescaled to
   * [mins[c], maxs[c]], where `c` cycles through `meta.mins` / `meta.maxs`.
   *
   * @param {Object<string, *>} compressDir - File-name → contents map.
   * @param {string} paramName - Parameter name (`scales`, `quats`, `opacities`, `sh0`).
   * @param {{shape: number[], mins: number[], maxs: number[]}} meta - Parameter metadata.
   * @returns {Promise<Array>} `sh0` → `[N][1][3]`, `opacities` → `[N][1]`,
   *   otherwise `[N][meta.shape[1]]`. When any dimension of `meta.shape` is not
   *   positive, nested zeros shaped like `meta.shape` are returned instead.
   */
  async _decompressPng(compressDir, paramName, meta) {
    if (!meta.shape.every((dim) => dim > 0)) {
      return this._zeros(meta.shape);
    }

    const { data: pixels } = await this._readImageFile(compressDir, `${paramName}.png`, paramName);
    const normalized = Float32Array.from(pixels, (value) => value / (2 ** 8 - 1));
    const params = this._denormalize(normalized, meta.mins, meta.maxs);

    const [count, width] = meta.shape;
    if (paramName === 'sh0') {
      return this._reshape(params, [count, 1, 3], paramName);
    }
    if (paramName === 'opacities') {
      return this._reshape(params, [count, 1], paramName);
    }
    return this._reshape(params, [count, width], paramName);
  }

  /**
   * Decodes a parameter stored as two 8-bit PNGs that together hold 16-bit values:
   * `<paramName>_u.png` is the high byte and `<paramName>_l.png` the low byte.
   *
   * The combined value is mapped to [0, 1] (divided by 65535) and then rescaled
   * per channel with `meta.mins` / `meta.maxs`.
   *
   * @param {Object<string, *>} compressDir - File-name → contents map.
   * @param {string} paramName - Parameter name (registered for `means`).
   * @param {{shape: number[], mins: number[], maxs: number[]}} meta - Parameter metadata.
   * @returns {Promise<Array>} `[N][meta.shape[1] ?? 3]` nested array, or nested
   *   zeros shaped like `meta.shape` when any dimension is not positive.
   */
  async _decompressPng16Bit(compressDir, paramName, meta) {
    if (!meta.shape.every((dim) => dim > 0)) {
      return this._zeros(meta.shape);
    }

    const [lower, upper] = await Promise.all([
      this._readImageFile(compressDir, `${paramName}_l.png`, paramName),
      this._readImageFile(compressDir, `${paramName}_u.png`, paramName),
    ]);

    const combined = Float32Array.from(
      lower.data,
      (low, i) => ((upper.data[i] << 8) + low) / (2 ** 16 - 1),
    );
    const params = this._denormalize(combined, meta.mins, meta.maxs);

    const depth = meta.shape[1] ?? 3;
    const expectedLength = meta.shape[0] * depth;
    if (params.length !== expectedLength) {
      console.error(`Expected length ${expectedLength}, but got ${params.length}.`);
    }
    return this._reshape(params, [meta.shape[0], depth], paramName);
  }

  /**
   * Decodes a k-means-compressed parameter from `<paramName>.npz`.
   *
   * The archive must contain `centroids.npy` (quantized centroid table) and
   * `labels.npy` (one centroid index per point). Centroid values are divided by
   * `2 ** meta.quantization - 1` and rescaled with `meta.mins` / `meta.maxs`
   * (used as scalars here — presumably a single global range; verify against
   * the writer). Each label is then replaced by its centroid row.
   *
   * @param {Object<string, *>} compressDir - File-name → contents map.
   * @param {string} paramName - Parameter name (registered for `shN`).
   * @param {{shape: number[], quantization: number, mins: number, maxs: number}} meta
   * @returns {Promise<Array>} `[N][C][D]` nested array per `meta.shape`, or nested
   *   zeros shaped like `meta.shape` on a non-positive dimension or on any error
   *   (the error is logged).
   */
  async _decompressKMeans(compressDir, paramName, meta) {
    if (!meta.shape.every((dim) => dim > 0)) {
      return this._zeros(meta.shape);
    }

    try {
      const npy = new npyjs();
      const zip = await JSZip.loadAsync(compressDir[`${paramName}.npz`]);

      const [centroidsBuffer, labelsBuffer] = await Promise.all([
        zip.files['centroids.npy'].async('arraybuffer'),
        zip.files['labels.npy'].async('arraybuffer'),
      ]);
      const centroidsParsed = await npy.parse(centroidsBuffer);
      const labelsParsed = await npy.parse(labelsBuffer);

      const { quantization, mins, maxs } = meta;
      const quantizationFactor = 2 ** quantization - 1;
      const centroidsData = centroidsParsed.data;
      const [numRows, numCols] = centroidsParsed.shape;

      // Element (row, col) is read from index col * numRows + row, i.e. the flat
      // data is treated as column-major. NOTE(review): this presumably matches a
      // Fortran-ordered .npy written from a permuted tensor — confirm with the writer.
      const centroids = Array.from({ length: numRows }, (_, row) =>
        Array.from({ length: numCols }, (_, col) =>
          (centroidsData[col * numRows + row] / quantizationFactor) * (maxs - mins) + mins,
        ),
      );

      // Number() also accepts BigInt labels should the file use a 64-bit dtype.
      const params = Array.from(labelsParsed.data, (label) => centroids[Number(label)]);
      return this._reshape(params, meta.shape, paramName);
    } catch (error) {
      console.error('Error during KMeans decompression:', error);
      return this._zeros(meta.shape);
    }
  }

  /**
   * Applies `sign(x) * (exp(|x|) - 1)` to every element of a 2-D nested array.
   *
   * @param {Array<Array<number>>} data - Rows of numbers.
   * @returns {Array<Array<number>>} A new array with the transform applied.
   */
  inverseLogTransform(data) {
    return data.map((row) =>
      row.map((val) => Math.sign(val) * Math.expm1(Math.abs(val))),
    );
  }

  /**
   * Splits flat data into nested arrays.
   *
   * - `sh0`: `shape = [N, _, D]` → `[N][1][D]`. Missing values become `undefined`.
   * - names starting with `shN`: `array` holds N flat rows; each row is cut into
   *   `C` chunks of `D` values (`shape = [N, C, D]`).
   * - otherwise: `shape = [N, D]` → `[N][D]`. Missing values become `undefined`.
   *
   * @param {ArrayLike<*>} array - Source data.
   * @param {number[]} shape - Target shape (see above).
   * @param {string} paraname - Parameter name that selects the layout.
   * @returns {Array} The nested result.
   * @throws {Error} For `shN` when `array.length` differs from `shape[0]`.
   */
  _reshape(array, shape, paraname) {
    if (paraname === 'sh0') {
      const [totalPixels, , depth] = shape;
      return this._takeRows(array, totalPixels, depth).map((slice) => [slice]);
    }

    if (paraname.startsWith('shN')) {
      const [totalPixels, c, depth] = shape;
      if (array.length !== totalPixels) {
        throw new Error(
          `Data length mismatch! Expected ${totalPixels}, but got ${array.length}.`,
        );
      }
      return Array.from({ length: totalPixels }, (_, i) => {
        const flatArray = array[i];
        return Array.from({ length: c }, (_, row) =>
          flatArray.slice(row * depth, (row + 1) * depth),
        );
      });
    }

    const [totalPixels, depth] = shape;
    return this._takeRows(array, totalPixels, depth);
  }

  /**
   * Reads `array` sequentially into `rows` rows of `depth` values each, padding
   * with `undefined` once the data runs out.
   */
  _takeRows(array, rows, depth) {
    let index = 0;
    return Array.from({ length: rows }, () =>
      Array.from({ length: depth }, () => (index < array.length ? array[index++] : undefined)),
    );
  }

  /**
   * Builds nested plain arrays of zeros with the given shape, without mutating it.
   * A zero-sized dimension yields empty arrays at that level.
   */
  _zeros(shape) {
    if (shape.length === 0) {
      return 0;
    }
    const [size, ...rest] = shape;
    return Array.from({ length: size }, () => this._zeros(rest));
  }

  /**
   * Min-max rescales normalized values channel by channel: value `i` uses
   * `mins[i % mins.length]` and `maxs[i % maxs.length]`.
   *
   * @param {Float32Array} values - Values in [0, 1].
   * @param {number[]} mins - Per-channel minimums.
   * @param {number[]} maxs - Per-channel maximums.
   * @returns {Float32Array} Rescaled values.
   */
  _denormalize(values, mins, maxs) {
    const lo = new Float32Array(mins);
    const hi = new Float32Array(maxs);
    return values.map(
      (value, i) => value * (hi[i % hi.length] - lo[i % lo.length]) + lo[i % lo.length],
    );
  }

  /**
   * Number of channels kept from the decoded RGBA image for a parameter:
   * 3 for `scales`/`sh0`/`means`, 4 for `quats`, 1 otherwise.
   */
  _channelsFor(paramName) {
    switch (paramName) {
      case 'scales':
      case 'sh0':
      case 'means':
        return 3;
      case 'quats':
        return 4;
      default:
        return 1;
    }
  }

  /**
   * Decodes the image stored under `fileName` in `compressDir`. The temporary
   * object URL is always revoked once decoding finishes or fails.
   *
   * @returns {Promise<{data: Float32Array, width: number, height: number, channels: number}>}
   * @throws {Error} If `compressDir` has no entry for `fileName`.
   */
  async _readImageFile(compressDir, fileName, paramName) {
    const blob = compressDir[fileName];
    if (blob == null) {
      throw new Error(`Missing image file ${fileName} for paramName: ${paramName}`);
    }
    const url = URL.createObjectURL(blob);
    try {
      return await this._loadImageData(url, paramName);
    } finally {
      URL.revokeObjectURL(url);
    }
  }

  /**
   * Loads an image URL through a 2D canvas and keeps the first
   * `_channelsFor(paramName)` channels of each RGBA pixel (values 0–255).
   *
   * NOTE(review): a 2D canvas may store pixels with premultiplied alpha. For
   * 4-channel images (quats), RGB values read back by getImageData can then lose
   * precision where alpha < 255, and be zeroed where alpha is 0. Confirm whether
   * a non-premultiplied decode path is needed.
   *
   * @returns {Promise<{data: Float32Array, width: number, height: number, channels: number}>}
   */
  async _loadImageData(url, paramName) {
    const channels = this._channelsFor(paramName);

    const img = await new Promise((resolve, reject) => {
      const image = new Image();
      image.onload = () => resolve(image);
      image.onerror = () => reject(new Error(`Failed to load image for paramName: ${paramName}`));
      image.src = url;
    });

    const canvas = document.createElement('canvas');
    canvas.width = img.width;
    canvas.height = img.height;
    const ctx = canvas.getContext('2d');
    ctx.drawImage(img, 0, 0);
    const { data: rgba, width, height } = ctx.getImageData(0, 0, canvas.width, canvas.height);

    const pixelCount = width * height;
    const data = new Float32Array(pixelCount * channels);
    for (let i = 0; i < pixelCount; i++) {
      for (let c = 0; c < channels; c++) {
        data[i * channels + c] = rgba[i * 4 + c]; // source is always 4-channel RGBA
      }
    }
    return { data, width, height, channels };
  }

  /**
   * Loads an image URL as a `[height, width, channels]` tf tensor.
   * Kept for callers that expect a tensor; the decoders above use
   * `_loadImageData` directly, so no tensor needs to be disposed.
   */
  async _loadImage(url, paramName) {
    const { data, width, height, channels } = await this._loadImageData(url, paramName);
    return tf.tensor(data, [height, width, channels]);
  }

  /**
   * Transposes a rectangular 2-D array; the column count is taken from the
   * first row. Returns `[]` for an empty input.
   */
  transpose2DArray = (array) => {
    if (array.length === 0) {
      return [];
    }
    const rowCount = array.length; // original number of rows
    const colCount = array[0].length; // original number of columns
    return Array.from({ length: colCount }, (_, col) =>
      Array.from({ length: rowCount }, (_, row) => array[row][col]),
    );
  };
}

export default Decompressor;