import npyjs from 'https://cdn.jsdelivr.net/npm/npyjs@0.6.0/+esm'

/**
 * Decompresses splat parameters from an in-memory "compress directory":
 * an object mapping file names to file contents.
 *
 * Entries read by this class:
 *   - 'meta.json': an object (already parsed, not a JSON string) mapping each
 *     parameter name to its metadata ({ shape, dtype, mins, maxs, ... }).
 *   - '<param>.png': 8-bit image for 'scales', 'quats', 'opacities', 'sh0'.
 *   - '<param>_l.png' / '<param>_u.png': low / high bytes of 16-bit 'means'.
 *   - '<param>.npz': archive with 'centroids.npy' and 'labels.npy' for 'shN'.
 *
 * PNG entries must be accepted by URL.createObjectURL (e.g. Blob/File); the
 * .npz entry must be accepted by JSZip.loadAsync. The class relies on the
 * globals `tf` (TensorFlow.js) and `JSZip`, on the imported `npyjs`, and on
 * browser DOM APIs (Image, document, canvas).
 */
class Decompressor {
  constructor() {
    console.log('Decompressor instance created');
    // Parameter name -> decoder. Bound so `decompress` can call them as
    // plain functions without losing `this`.
    this.decompressFnMap = {
      means: this._decompressPng16Bit.bind(this),
      scales: this._decompressPng.bind(this),
      quats: this._decompressPng.bind(this),
      opacities: this._decompressPng.bind(this),
      sh0: this._decompressPng.bind(this),
      shN: this._decompressKMeans.bind(this),
    };
  }

  /**
   * Returns the decoder registered for `paramName`, or undefined (after
   * logging a warning) when there is none.
   *
   * Only own keys of `decompressFnMap` count, so meta.json keys such as
   * 'constructor' or 'toString' cannot resolve to Object.prototype members
   * and be invoked as decoders.
   *
   * @param {string} paramName
   * @returns {Function|undefined}
   */
  _getDecompressFn(paramName) {
    if (!Object.hasOwn(this.decompressFnMap, paramName)) {
      console.warn(`No decompression function found for paramName: ${paramName}`);
      return undefined;
    }
    return this.decompressFnMap[paramName];
  }

  /**
   * Decodes every parameter listed in compressDir['meta.json'].
   *
   * Parameters without a registered decoder are skipped. The decoded
   * 'means' are additionally passed through `inverseLogTransform`.
   *
   * @param {Object} compressDir - File name -> contents (see class docs).
   * @returns {Promise<Object>} Parameter name -> decoded data.
   */
  async decompress(compressDir) {
    const meta = compressDir['meta.json'];
    const splats = {};

    // Decoded sequentially, one parameter at a time.
    for (const [paramName, paramMeta] of Object.entries(meta)) {
      const decompressFn = this._getDecompressFn(paramName);
      if (!decompressFn) {
        continue; // _getDecompressFn has already warned about this name.
      }
      splats[paramName] = await decompressFn(compressDir, paramName, paramMeta);
    }

    // Param-specific postprocessing.
    if (splats.means) {
      splats.means = this.inverseLogTransform(splats.means);
    }

    return splats;
  }

  /**
   * Decodes an 8-bit PNG parameter ('scales', 'quats', 'opacities', 'sh0').
   *
   * Each byte v becomes v / 255 * (max - min) + min, with min/max taken per
   * channel from meta.mins / meta.maxs (cycled with i % length; a plain
   * number is treated as a single channel).
   *
   * @param {Object} compressDir - File name -> contents.
   * @param {string} paramName - Parameter being decoded.
   * @param {Object} meta - { shape, dtype, mins, maxs } for this parameter.
   * @returns {Promise<Array|Object>} Rows from `_reshape`: sh0 -> [N][1][3],
   *   opacities -> [N][1], others -> [N][shape[1]]. When any dimension of
   *   meta.shape is not positive, returns tf.zeros(meta.shape, meta.dtype).
   */
  async _decompressPng(compressDir, paramName, meta) {
    if (!meta.shape.every((dim) => dim > 0)) {
      return tf.zeros(meta.shape, meta.dtype);
    }

    const image = await this._loadPixels(compressDir[`${paramName}.png`], paramName);

    const mins = this._toFloat32Array(meta.mins);
    const maxs = this._toFloat32Array(meta.maxs);
    const normalizationFactor = 1 / 255; // 8-bit -> [0, 1]
    const imgData = new Float32Array(image.length);

    for (let i = 0; i < image.length; i++) {
      const normalizedValue = image[i] * normalizationFactor;
      const idx = i % mins.length; // per-channel min/max for interleaved data
      imgData[i] = normalizedValue * (maxs[idx] - mins[idx]) + mins[idx];
    }

    switch (paramName) {
      case 'sh0':
        return this._reshape(imgData, [meta.shape[0], 1, 3], paramName);
      case 'opacities':
        return this._reshape(imgData, [meta.shape[0], 1], paramName);
      default:
        return this._reshape(imgData, [meta.shape[0], meta.shape[1]], paramName);
    }
  }

  /**
   * Decodes a 16-bit parameter stored as two 8-bit PNGs: '<param>_l.png'
   * holds the low byte and '<param>_u.png' the high byte. The recombined
   * value is rescaled per channel with meta.mins / meta.maxs and reshaped
   * to [N][3].
   *
   * When any dimension of meta.shape is not positive, returns a zero-filled
   * Float32Array with one element per dimension; meta.shape is not modified.
   *
   * @param {Object} compressDir - File name -> contents.
   * @param {string} paramName - Parameter being decoded (e.g. 'means').
   * @param {Object} meta - { shape, mins, maxs } for this parameter.
   * @returns {Promise<Array|Float32Array>}
   */
  async _decompressPng16Bit(compressDir, paramName, meta) {
    if (!meta.shape.every((dim) => dim > 0)) {
      return new Float32Array(meta.shape.length);
    }

    // The two byte planes are independent files, so decode them concurrently.
    const [imgL, imgU] = await Promise.all([
      this._loadPixels(compressDir[`${paramName}_l.png`], paramName),
      this._loadPixels(compressDir[`${paramName}_u.png`], paramName),
    ]);

    const mins = this._toFloat32Array(meta.mins);
    const maxs = this._toFloat32Array(meta.maxs);
    const maxUint16 = 2 ** 16 - 1;
    const imgData = new Float32Array(imgL.length);

    for (let i = 0; i < imgL.length; i++) {
      // Recombine the two bytes into [0, 1]. fround reproduces the float32
      // rounding of the intermediate value before it is rescaled.
      const normalized = Math.fround(((imgU[i] << 8) + imgL[i]) / maxUint16);
      const idx = i % mins.length;
      imgData[i] = normalized * (maxs[idx] - mins[idx]) + mins[idx];
    }

    const expectedLength = meta.shape[0] * 3;
    if (imgData.length !== expectedLength) {
      console.error(`Expected length ${expectedLength}, but got ${imgData.length}.`);
    }

    return this._reshape(imgData, [meta.shape[0], 3], paramName);
  }

  /**
   * Decodes k-means-compressed parameters (e.g. 'shN') from '<param>.npz'.
   *
   * The archive holds 'centroids.npy' (quantized centroids, read as uint8,
   * shape [K, D]) and 'labels.npy' (centroid indices, read as uint16, one
   * per point). Centroids are dequantized by 2 ** meta.quantization - 1,
   * rescaled with meta.mins / meta.maxs (used as plain numbers), looked up
   * per label, and reshaped to meta.shape by `_reshape`.
   *
   * The centroid buffer is read in column-major (Fortran) order unless the
   * parser reports `fortranOrder === false`, in which case it is read
   * row-major.
   *
   * Any failure is logged and yields a zero-filled Float32Array with one
   * element per dimension, as does a shape with a non-positive dimension.
   * meta.shape is never modified.
   *
   * @param {Object} compressDir - File name -> contents.
   * @param {string} paramName - Parameter being decoded.
   * @param {Object} meta - { shape, quantization, mins, maxs }.
   * @returns {Promise<Array|Float32Array>}
   */
  async _decompressKMeans(compressDir, paramName, meta) {
    if (!meta.shape.every((dim) => dim > 0)) {
      return new Float32Array(meta.shape.length);
    }

    try {
      const npy = new npyjs();
      const zip = await JSZip.loadAsync(compressDir[`${paramName}.npz`]);

      const readEntry = (name) => {
        const entry = zip.files[name];
        if (!entry) {
          throw new Error(`${paramName}.npz has no ${name}`);
        }
        return entry.async('arraybuffer');
      };

      const [centroidsBuffer, labelsBuffer] = await Promise.all([
        readEntry('centroids.npy'),
        readEntry('labels.npy'),
      ]);

      const centroidsParsed = await npy.parse(centroidsBuffer);
      const labelsParsed = await npy.parse(labelsBuffer);

      const centroidsData = new Uint8Array(centroidsParsed.data);
      const labels = new Uint16Array(labelsParsed.data);

      const { quantization, mins, maxs } = meta;
      const quantizationFactor = 2 ** quantization - 1;
      const [numClusters, numFeatures] = centroidsParsed.shape;
      const isRowMajor = centroidsParsed.fortranOrder === false;

      // Build one dequantized, rescaled row of length numFeatures per cluster.
      const centroidsScaled = Array.from({ length: numClusters }, (_, k) => {
        const row = new Array(numFeatures);
        for (let d = 0; d < numFeatures; d++) {
          const flatIndex = isRowMajor ? k * numFeatures + d : d * numClusters + k;
          const normalized = centroidsData[flatIndex] / quantizationFactor;
          row[d] = normalized * (maxs - mins) + mins;
        }
        return row;
      });

      // Each point takes the centroid row selected by its label.
      const params = Array.from(labels, (label) => centroidsScaled[label]);

      return this._reshape(params, meta.shape, paramName);
    } catch (error) {
      console.error(`KMeans decompression failed for ${paramName}:`, error);
      return new Float32Array(meta.shape.length);
    }
  }

  /**
   * Applies sign(x) * expm1(|x|) to every value.
   *
   * Accepts either an array of rows (each row mapped element-wise, keeping
   * its array type) or a flat array of numbers (mapped directly).
   *
   * @param {Array|ArrayLike} data
   * @returns {Array|ArrayLike} Same structure as `data`.
   */
  inverseLogTransform(data) {
    const expandLog = (val) => Math.sign(val) * Math.expm1(Math.abs(val));
    return data.map((row) => (typeof row === 'number' ? expandLog(row) : row.map(expandLog)));
  }

  /**
   * Reshapes decoded data into nested per-point arrays.
   *
   * - 'sh0': shape [N, _, depth] -> N entries, each `[Float32Array(depth)]`.
   * - names starting with 'shN': `array` must already hold one flat row per
   *   point (length N); shape [N, c, depth] -> N entries of c slices, each of
   *   length depth (slices keep the row's array type).
   * - otherwise: shape [N, depth] -> N `Float32Array(depth)` rows.
   *
   * For the typed-array layouts, positions past the end of `array` receive
   * `undefined`, which a Float32Array stores as NaN.
   *
   * @param {ArrayLike} array - Flat values, or flat rows for 'shN'.
   * @param {number[]} shape - Target shape (see above).
   * @param {string} paraname - Parameter name selecting the layout.
   * @returns {Array}
   * @throws {Error} For 'shN' when `array.length` differs from N.
   */
  _reshape(array, shape, paraname) {
    let index = 0;

    // Copies the next `depth` values from `array` into a new Float32Array row.
    const takeRow = (depth) => {
      const row = new Float32Array(depth);
      for (let j = 0; j < depth; j++) {
        row[j] = index < array.length ? array[index++] : undefined;
      }
      return row;
    };

    if (paraname === 'sh0') {
      const [totalPixels, , depth] = shape;
      return Array.from({ length: totalPixels }, () => [takeRow(depth)]);
    }

    if (paraname.startsWith('shN')) {
      const [totalPixels, c, depth] = shape;
      if (array.length !== totalPixels) {
        throw new Error(
          `Data length mismatch! Expected ${totalPixels}, but got ${array.length}.`
        );
      }
      return Array.from(array, (flatArray) =>
        Array.from({ length: c }, (_, row) => flatArray.slice(row * depth, (row + 1) * depth))
      );
    }

    const [totalPixels, depth] = shape;
    return Array.from({ length: totalPixels }, () => takeRow(depth));
  }

  /**
   * Loads one image entry through `_loadImage` and returns its flat channel
   * values. The temporary object URL is always revoked, and a tf.Tensor
   * result is disposed once its values have been read.
   *
   * @param {Blob} blob - Image file contents from the compress directory.
   * @param {string} paramName - Forwarded to `_loadImage` (selects channels).
   * @returns {Promise<ArrayLike<number>>}
   */
  async _loadPixels(blob, paramName) {
    const url = URL.createObjectURL(blob);
    try {
      const image = await this._loadImage(url, paramName);
      if (!(image instanceof tf.Tensor)) {
        return image;
      }
      try {
        return image.dataSync();
      } finally {
        image.dispose();
      }
    } finally {
      URL.revokeObjectURL(url);
    }
  }

  /**
   * Converts a meta `mins` / `maxs` entry to a Float32Array. A plain number
   * becomes a one-element array; array-likes are copied element-wise.
   *
   * @param {number|ArrayLike<number>} value
   * @returns {Float32Array}
   */
  _toFloat32Array(value) {
    return typeof value === 'number' ? Float32Array.of(value) : new Float32Array(value);
  }

  /**
   * Decodes an image URL via an <img> element and a 2D canvas.
   *
   * The result is a tf tensor of shape [height, width, channels] holding
   * the raw 0-255 channel values as float32. Channels: 3 (R, G, B) for
   * 'scales', 'sh0' and 'means'; 4 (R, G, B, A) for 'quats'; otherwise 1
   * (the R channel only).
   *
   * NOTE(review): canvas readback may not return the exact PNG bytes
   * (premultiplied alpha for the RGBA 'quats' image, possible color-space
   * conversion). Confirm this is acceptable for parameter data.
   *
   * @param {string} url - Image URL (object URL in this class).
   * @param {string} paramName - Selects the channel count.
   * @returns {Promise<Object>} tf tensor.
   * @throws {Error} When the image fails to load.
   */
  async _loadImage(url, paramName) {
    let numChannels;
    switch (paramName) {
      case 'scales':
      case 'sh0':
      case 'means':
        numChannels = 3;
        break;
      case 'quats':
        numChannels = 4;
        break;
      default:
        numChannels = 1;
    }

    const img = await new Promise((resolve, reject) => {
      const image = new Image();
      image.onload = () => resolve(image);
      image.onerror = () => reject(new Error(`Failed to load image for ${paramName}`));
      image.src = url;
    });

    const canvas = document.createElement('canvas');
    canvas.width = img.width;
    canvas.height = img.height;

    const ctx = canvas.getContext('2d');
    ctx.drawImage(img, 0, 0);

    const { width, height } = canvas;
    const rgba = ctx.getImageData(0, 0, width, height).data; // 4 bytes per pixel
    const pixelCount = width * height;
    const reshapedData = new Float32Array(pixelCount * numChannels);

    // Keep the first `numChannels` of each RGBA pixel.
    for (let p = 0; p < pixelCount; p++) {
      for (let c = 0; c < numChannels; c++) {
        reshapedData[p * numChannels + c] = rgba[p * 4 + c];
      }
    }

    return tf.tensor(reshapedData, [height, width, numChannels]);
  }

  /**
   * Transposes a rectangular 2D array: result[j][i] === array[i][j].
   * The column count is taken from the first row; an empty input yields [].
   *
   * @param {Array<ArrayLike>} array
   * @returns {Array<Array>}
   */
  transpose2DArray = (array) => {
    if (array.length === 0) {
      return [];
    }
    const numRows = array.length; // row count of the input
    const numCols = array[0].length; // column count of the input
    return Array.from({ length: numCols }, (_, j) =>
      Array.from({ length: numRows }, (_, i) => array[i][j])
    );
  };
}

// Default export of this module: the Decompressor class.
export default Decompressor;