import * as THREE from 'three';
import { clamp } from '../../Util.js';
import { UncompressedSplatArray } from '../UncompressedSplatArray.js';
import {LoaderStatus} from "../LoaderStatus.js";

/**
 * Converts in-memory Gaussian-splat attribute arrays into an UncompressedSplatArray.
 * Splats are processed in batches that are scheduled one after another, so a large
 * conversion is not performed as a single long task.
 */
class PngSaver {
    /**
     * @param {Object} splats - Source attributes; `extractData` reads `means`,
     *   `opacities`, `sh0`, `shN`, `scales` and `quats` from it.
     * @param {number} maxShDegree - Highest spherical-harmonics degree to copy into each
     *   output splat; only degrees 1 and 2 are written.
     */
    constructor(splats, maxShDegree) {
        this.splats = splats;
        this.maxShDegree = maxShDegree;
        // Scratch quaternion reused for every splat so normalization does not allocate.
        this.tempRotation = new THREE.Quaternion();
    }

    /**
     * Copies the attribute arrays out of `this.splats` onto the instance and flattens the
     * higher-order SH coefficients into the layout `createUncompressedSplat` reads.
     */
    extractData() {
        const { means, opacities, sh0, shN, scales, quats } = this.splats;

        this.xyz = means;  // expected shape (N, 3)
        this.opacities = opacities;  // expected shape (N,)
        // Each sh0 entry is read as [[dc0, dc1, dc2]]; keep only the inner triple.
        this.featuresDc = sh0.map((f) => f[0]);
        // Each shN entry is read as K rows of [c0, c1, c2]. Transpose and flatten it into
        // channel-major order: [c0 coeffs 0..K-1, c1 coeffs 0..K-1, c2 coeffs 0..K-1].
        // A missing shN, or an empty per-splat entry, yields no extra coefficients.
        this.featuresExtra = shN
            ? shN.map((f) => (f.length === 0
                ? []
                : [].concat(...f[0].map((_, colIndex) => f.map((row) => row[colIndex])))))
            : [];
        this.scales = scales;  // log-space scales, expected shape (N, 3)
        this.rots = quats;  // expected shape (N, 4)
    }

    /**
     * Convert every splat into an UncompressedSplatArray, yielding between batches and
     * reporting progress along the way.
     * @param {Function} [onProgress] - Called as (processedCount, totalSplats, status):
     *   with LoaderStatus.Processing while work remains, and exactly once with
     *   LoaderStatus.Done when every splat has been converted. May be omitted.
     * @param {number} [batchSize=50000] - Number of splats to convert per batch
     *   (values below 1 are treated as 1).
     * @param {number} [progressInterval=10000] - Report Processing once at least this many
     *   splats have been converted since the previous report.
     * @param {number} [progressIntervalMs=100] - Also report Processing once this many
     *   milliseconds have passed since the previous report.
     * @returns {Promise<UncompressedSplatArray>} - Processed splat array; rejects if
     *   converting any batch throws.
     */
    async parseToUncompressedSplats(onProgress, batchSize = 50000, progressInterval = 10000, progressIntervalMs = 100) {
        this.extractData();
        const splatArray = new UncompressedSplatArray(2);
        const totalSplats = this.xyz.length;
        const { xyz, opacities, featuresDc, featuresExtra, scales, rots } = this;

        const reportProgress = typeof onProgress === 'function' ? onProgress : () => {};
        // requestIdleCallback is unavailable in Node and some browsers; fall back to a macrotask.
        const scheduleNextBatch = typeof requestIdleCallback === 'function'
            ? (callback) => requestIdleCallback(callback)
            : (callback) => setTimeout(callback, 0);
        // A batch size below 1 would never advance and would reschedule forever.
        const step = Math.max(1, Math.floor(batchSize) || 1);

        let processedCount = 0;
        let lastReportedCount = 0;
        let lastProgressTime = Date.now();

        const processBatch = () => {
            const end = Math.min(processedCount + step, totalSplats);
            for (let i = processedCount; i < end; i++) {
                splatArray.addSplat(this.createUncompressedSplat(
                    xyz[i], opacities[i], featuresDc[i], featuresExtra[i], scales[i], rots[i]
                ));
            }
            processedCount = end;

            if (processedCount === totalSplats) {
                reportProgress(processedCount, totalSplats, LoaderStatus.Done);
                return;
            }

            const currentTime = Date.now();
            const reachedCountInterval = processedCount - lastReportedCount >= progressInterval;
            const reachedTimeInterval = currentTime - lastProgressTime >= progressIntervalMs;
            if (reachedCountInterval || reachedTimeInterval) {
                reportProgress(processedCount, totalSplats, LoaderStatus.Processing);
                lastReportedCount = processedCount;
                lastProgressTime = currentTime;
            }
        };

        await new Promise((resolve, reject) => {
            const runBatch = () => {
                try {
                    processBatch();
                } catch (error) {
                    // An exception thrown inside a scheduled callback would otherwise escape
                    // this promise and leave the caller awaiting forever.
                    reject(error);
                    return;
                }
                if (processedCount < totalSplats) {
                    scheduleNextBatch(runBatch);
                } else {
                    resolve();
                }
            };
            runBatch();
        });

        return splatArray;
    }

    /**
     * Build one UncompressedSplatArray entry from the raw attributes of a single splat.
     * @param {number[]} xyz - Position; components 0..2 are copied as-is.
     * @param {number} opacity - Opacity logit; stored as floor(sigmoid(opacity) * 255).
     * @param {number[]} featuresDc - DC SH coefficients for the three colour channels.
     * @param {number[]} [featuresExtra] - Higher-order SH coefficients in channel-major
     *   order (length 3 * coefficientsPerChannel).
     * @param {number[]} scaleCoeffs - Log-space scales; non-finite entries fall back to 0.01
     *   before exponentiation.
     * @param {number[]} quatCoeffs - Rotation components; normalized, then stored in the
     *   same order they were given.
     * @returns {Array} - The populated splat.
     */
    createUncompressedSplat(xyz, opacity, featuresDc, featuresExtra, scaleCoeffs, quatCoeffs) {
        const OFFSET = UncompressedSplatArray.OFFSET;
        // Always allocated with degree-2 SH slots; maxShDegree only decides which are filled.
        const newSplat = UncompressedSplatArray.createSplat(2);

        newSplat[OFFSET.X] = xyz[0];
        newSplat[OFFSET.Y] = xyz[1];
        newSplat[OFFSET.Z] = xyz[2];

        // Number.isFinite (rather than `||`) keeps a valid log-scale of 0 mapping to
        // exp(0) = 1; only missing or non-finite values use the 0.01 fallback.
        const expScale = (coeff) => Math.exp(Number.isFinite(coeff) ? coeff : 0.01);
        newSplat[OFFSET.SCALE0] = expScale(scaleCoeffs[0]);
        newSplat[OFFSET.SCALE1] = expScale(scaleCoeffs[1]);
        newSplat[OFFSET.SCALE2] = expScale(scaleCoeffs[2]);

        // DC colour: (0.5 + SH_C0 * dc) scaled to a byte. DC coefficients beyond roughly
        // +/-1.77 fall outside [0, 255], so clamp before storing.
        const SH_C0 = 0.28209479177387814;
        const dcToByte = (coeff) => clamp(Math.floor((0.5 + SH_C0 * coeff) * 255), 0, 255);
        newSplat[OFFSET.FDC0] = dcToByte(featuresDc[0]);
        newSplat[OFFSET.FDC1] = dcToByte(featuresDc[1]);
        newSplat[OFFSET.FDC2] = dcToByte(featuresDc[2]);

        // Sigmoid of the opacity logit, scaled to a byte and clamped to [0, 255].
        newSplat[OFFSET.OPACITY] = clamp(Math.floor((1 / (1 + Math.exp(-opacity))) * 255), 0, 255);

        const quat = this.tempRotation
            .set(quatCoeffs[0], quatCoeffs[1], quatCoeffs[2], quatCoeffs[3])
            .normalize();
        newSplat[OFFSET.ROTATION0] = quat.x;
        newSplat[OFFSET.ROTATION1] = quat.y;
        newSplat[OFFSET.ROTATION2] = quat.z;
        newSplat[OFFSET.ROTATION3] = quat.w;

        if (this.maxShDegree >= 1 && featuresExtra) {
            // Derive the per-channel count from the data instead of assuming degree-3 input
            // (15 per channel), so degree-1 and degree-2 inputs are indexed correctly.
            // Coefficient k of channel c sits at featuresExtra[c * perChannel + k].
            const perChannel = Math.floor(featuresExtra.length / 3);
            const writeDegree1 = perChannel >= 3;
            const writeDegree2 = this.maxShDegree >= 2 && perChannel >= 8;
            for (let channel = 0; channel < 3 && writeDegree1; channel++) {
                const base = channel * perChannel;
                // Degree 1: coefficients 0..2 of each channel -> FRC0..FRC8.
                for (let k = 0; k < 3; k++) {
                    newSplat[OFFSET.FRC0 + channel * 3 + k] = featuresExtra[base + k];
                }
                // Degree 2: coefficients 3..7 of each channel -> FRC9..FRC23.
                if (writeDegree2) {
                    for (let k = 0; k < 5; k++) {
                        newSplat[OFFSET.FRC9 + channel * 5 + k] = featuresExtra[base + 3 + k];
                    }
                }
            }
        }

        return newSplat;
    }
}

export default PngSaver;