import React, { useMemo } from "react";
import {
    VictoryChart,
    VictoryLine,
    VictoryAxis,
    VictoryTheme,
    VictoryScatter,
    VictoryBoxPlot,
    VictoryHistogram,
    VictoryStack,
} from "victory";
import {
    DataDescription,
    PredictionDataObject,
    StatisticsObject,
} from "./types";
import { useModelSuiteColorModeClassAtom } from "@/Layouts/model-suite";

/**
 * Abbreviates large axis tick values for compact labels.
 *
 * Values whose magnitude is at least one million get an "m" suffix, and those
 * of at least one thousand a "k" suffix (e.g. 2500 -> "2.5k",
 * -3000000 -> "-3m"). Smaller values are returned unchanged. The magnitude is
 * used so negative values are abbreviated the same way as positive ones.
 */
const tickFormat = (tick: number): string | number => {
    const magnitude = Math.abs(tick);
    if (magnitude >= 1000000) {
        return `${tick / 1000000}m`;
    }
    if (magnitude >= 1000) {
        return `${tick / 1000}k`;
    }
    return tick;
};

/** Tick formatter that renders each tick via its own `toString()` (used for the forecast time axis). */
const stringTickFormat = (tick) => tick.toString();

/**
 * Theme-aware colours and Victory axis styles for the charts in this file.
 *
 * Reads the model-suite colour mode:
 * - when it is `"dark"`, the foreground `color` is #cbd5e1 and `contrastColor` is #1e293b;
 * - otherwise the two are swapped.
 *
 * `dependentAxisStyle` and `independentAxisStyle` apply the foreground colour
 * to tick labels, axis labels, grid lines, axis lines and ticks.
 *
 * The returned object is memoised on the colour mode, so its identity only
 * changes when the mode does.
 */
const useChartStyles = () => {
    const colorModeClass = useModelSuiteColorModeClassAtom();

    // Returned directly (not spread into a fresh object) so the memoised
    // reference stays stable across renders.
    return useMemo(() => {
        const isDark = colorModeClass === "dark";
        const color = isDark ? "#cbd5e1" : "#1e293b";
        const contrastColor = isDark ? "#1e293b" : "#cbd5e1";

        // 10px text in the foreground colour with the given padding.
        const textStyle = (padding: number) => ({
            padding,
            fontSize: 10,
            stroke: color,
            fill: color,
        });

        // Faint grid lines plus slightly stronger axis/tick lines. A factory,
        // so each axis gets its own style objects.
        const lineStyles = () => ({
            grid: { strokeOpacity: 0.1, stroke: color },
            axis: { strokeOpacity: 0.25, stroke: color },
            ticks: { strokeOpacity: 0.25, stroke: color },
        });

        return {
            color,
            contrastColor,
            dependentAxisStyle: {
                tickLabels: textStyle(5),
                ...lineStyles(),
                // NOTE(review): every chart, CentroidBoxPlotComponent included,
                // shares this 40px label padding; an older note said the box
                // plot used 25 — confirm which is intended.
                axisLabel: textStyle(40),
            },
            independentAxisStyle: {
                axisLabel: textStyle(25),
                tickLabels: textStyle(2),
                ...lineStyles(),
            },
        };
    }, [colorModeClass]);
};

/**
 * Scatter plot of observed rows with a fitted regression line drawn over it.
 *
 * `predictionDataObject.x` / `.y` name the columns to plot. They label the
 * axes and select the values taken from each row of `regressionDataPoints`.
 * `predictionDataObject.points` are the `{ x, y }` points of the fitted line.
 *
 * A pulsing placeholder is rendered instead of the chart while `points` is
 * missing, is empty, or has equal first and last x values.
 */
const RegressionChartComponent: React.FC<{
    predictionDataObject: PredictionDataObject;
    regressionDataPoints: Record<string, number>[]; // observed rows keyed by column name
}> = ({ predictionDataObject, regressionDataPoints }) => {
    const { x, y, points } = predictionDataObject;

    const { dependentAxisStyle, independentAxisStyle } = useChartStyles();

    // Observed rows projected onto the selected columns (undefined when no
    // rows were passed, exactly as the scatter received before).
    const scatterData = useMemo(
        () =>
            regressionDataPoints?.map((item) => ({
                x: item[x],
                y: item[y],
            })),
        [regressionDataPoints, x, y],
    );

    // Largest x and y over the fitted line AND the observed points, since both
    // series are plotted on these axes. They decide whether ticks switch to
    // the compact "k"/"m" format. Plain loops are used instead of
    // Math.max(...array), which throws a RangeError for very large arrays.
    const { maxX, maxY } = useMemo(() => {
        let largestX = -Infinity;
        let largestY = -Infinity;
        for (const point of points ?? []) {
            if (point?.x > largestX) largestX = point.x;
            if (point?.y > largestY) largestY = point.y;
        }
        for (const point of scatterData ?? []) {
            if (point.x > largestX) largestX = point.x;
            if (point.y > largestY) largestY = point.y;
        }
        return { maxX: largestX, maxY: largestY };
    }, [points, scatterData]);

    if (
        !points ||
        points.length === 0 ||
        points[0]?.x === points[points.length - 1]?.x
    ) {
        return (
            <div className="m-3 h-[400px] animate-pulse rounded-md bg-slate-600/10 dark:bg-slate-100/10"></div>
        );
    }

    return (
        <div className="-mt-10 flex w-full flex-col items-center px-4">
            <VictoryChart
                theme={VictoryTheme.material}
                width={450}
                domainPadding={{ y: 15 }}
            >
                <VictoryAxis
                    label={x}
                    style={independentAxisStyle}
                    tickFormat={maxX > 10000 ? tickFormat : null}
                />
                <VictoryAxis
                    dependentAxis
                    label={y}
                    style={dependentAxisStyle}
                    tickFormat={maxY > 10000 ? tickFormat : null}
                />
                <VictoryLine
                    data={points}
                    style={{
                        data: { stroke: "#254b9b", strokeWidth: 5 },
                    }}
                />
                <VictoryScatter
                    style={{ data: { fill: "#F4511E" } }}
                    size={2}
                    data={scatterData}
                />
            </VictoryChart>
        </div>
    );
};

export const RegressionChart = React.memo(RegressionChartComponent);

/**
 * Elbow-method chart: inertia plotted against the number of clusters.
 *
 * `inertiaValues[i]` is drawn at x = i + 1. The point whose x equals
 * `clusters` is highlighted with a larger purple marker. A pulsing placeholder
 * is rendered while `inertiaValues` is empty.
 */
const ClusteringElbowTestChartComponent: React.FC<{
    inertiaValues: number[];
    clusters: number;
}> = ({ inertiaValues, clusters }) => {
    const { dependentAxisStyle, independentAxisStyle } = useChartStyles();

    const lineData = useMemo(
        () => inertiaValues.map((y, i) => ({ x: i + 1, y })),
        [inertiaValues],
    );

    const maxY = useMemo(() => {
        return inertiaValues.length > 0 ? Math.max(...inertiaValues) : 0;
    }, [inertiaValues]);

    if (inertiaValues.length === 0) {
        return (
            <div className="m-3 h-[400px] animate-pulse rounded-md bg-slate-600/10 dark:bg-slate-100/10"></div>
        );
    }

    // Always show at least 0-10 on the x-axis, but widen it when more cluster
    // counts were evaluated, so trailing points aren't clipped by a fixed domain.
    const xDomainMax = Math.max(10, lineData.length);

    return (
        <div className="-mt-10 flex w-full flex-col items-center px-4">
            <VictoryChart
                theme={VictoryTheme.material}
                width={450}
                domainPadding={{ y: 15 }}
            >
                <VictoryAxis
                    label={"Number of Clusters"}
                    style={independentAxisStyle}
                    domain={[0, xDomainMax]}
                />
                <VictoryAxis
                    dependentAxis
                    label={"Inertia (Sum of Squared Differences)"}
                    style={dependentAxisStyle}
                    tickFormat={maxY > 10000 ? tickFormat : null}
                />
                <VictoryLine
                    data={lineData}
                    style={{
                        data: { stroke: "#8BC34A", strokeWidth: 2 },
                    }}
                />
                <VictoryScatter
                    style={{ data: { fill: "#8BC34A" } }}
                    size={4}
                    data={lineData.filter((point) => point.x !== clusters)}
                />
                <VictoryScatter
                    style={{
                        data: {
                            fill: "purple",
                            stroke: "#8BC34A",
                            strokeWidth: 2,
                        },
                    }}
                    size={7}
                    data={lineData.filter((point) => point.x === clusters)}
                />
            </VictoryChart>
        </div>
    );
};

export const ClusteringElbowTestChart = React.memo(
    ClusteringElbowTestChartComponent,
);

/**
 * Builds VictoryBoxPlot data for `variable` from per-cluster descriptive
 * statistics.
 *
 * Each cluster that has statistics for `variable` becomes one box:
 * - `min` / `max` come from the `min` / `max` entries;
 * - `q1` / `median` / `q3` come from the "25%" / "50%" / "75%" entries.
 *
 * Boxes are labelled "Cluster N", where N is the cluster's 1-based position
 * in `data`. The position is taken before clusters lacking the variable are
 * skipped, so a skipped cluster does not shift the labels of later ones.
 *
 * The result is in reverse cluster order (last cluster first).
 */
const transformDataForBoxPlot = (
    data: Record<number, DataDescription>,
    variable: string,
) => {
    return Object.keys(data)
        .map((cluster, position) => ({
            label: `Cluster ${position + 1}`,
            stats: data[cluster][variable],
        }))
        .filter(({ stats }) => !!stats)
        .map(({ label, stats }) => ({
            x: label,
            min: stats.min,
            max: stats.max,
            q1: stats["25%"],
            q3: stats["75%"],
            median: stats["50%"],
        }))
        // `map` produced a fresh array, so reversing it in place is safe and
        // avoids the quadratic `[point, ...acc]` rebuild.
        .reverse();
};

/** Palette used to colour clusters, indexed by cluster position. */
const centroidColors: string[] = [
    "#254b9b", // blue
    "#e27a3f", // orange
    "#774fa1", // purple
    "#45b285", // green
    "#df5a49", // red
    "#efc94c", // yellow
];

/**
 * Horizontal box plot of one variable across clusters.
 *
 * Boxes come from `transformDataForBoxPlot(clusterDescriptions, variable)`,
 * which lists clusters in reverse order. Each box and its whiskers are
 * coloured from `centroidColors` by the box's position in un-reversed order.
 * A pulsing placeholder is shown when there are no boxes to draw.
 */
const CentroidBoxPlotComponent: React.FC<{
    clusterDescriptions: Record<number, DataDescription>;
    variable: string;
}> = ({ clusterDescriptions, variable }) => {
    const { dependentAxisStyle, independentAxisStyle } = useChartStyles();

    const data = useMemo(() => {
        return transformDataForBoxPlot(clusterDescriptions, variable);
    }, [clusterDescriptions, variable]);

    const maxY = useMemo(() => {
        return data.length > 0
            ? Math.max(...data.map((point) => point.max))
            : 0;
    }, [data]);

    if (!data || data.length === 0) {
        return (
            <div className="m-3 h-[300px] animate-pulse rounded-md bg-slate-600/10 dark:bg-slate-100/10"></div>
        );
    }

    // Undo the reversal so the first cluster gets centroidColors[0], the next
    // centroidColors[1], and so on, cycling past the end of the palette.
    const boxColor = ({ index }: { index?: number | string }) =>
        centroidColors[
            (data.length - 1 - (index as number)) % centroidColors.length
        ];

    return (
        <div className="flex w-full flex-col items-center">
            <VictoryChart
                theme={VictoryTheme.material}
                width={450}
                domainPadding={{ x: 20, y: 15 }}
                horizontal
            >
                <VictoryAxis style={independentAxisStyle} />
                <VictoryAxis
                    dependentAxis
                    label={variable}
                    style={dependentAxisStyle}
                    tickFormat={maxY > 10000 ? tickFormat : null}
                />
                <VictoryBoxPlot
                    data={data}
                    boxWidth={20}
                    style={{
                        min: { stroke: boxColor, strokeWidth: 1 },
                        max: { stroke: boxColor, strokeWidth: 1 },
                        q1: {
                            fill: boxColor,
                            stroke: boxColor,
                            strokeWidth: 1,
                            fillOpacity: 0.5,
                        },
                        q3: {
                            fill: boxColor,
                            stroke: boxColor,
                            strokeWidth: 1,
                            fillOpacity: 0.5,
                        },
                        median: { stroke: boxColor, strokeWidth: 2 },
                    }}
                />
            </VictoryChart>
        </div>
    );
};

export const CentroidBoxPlot = React.memo(CentroidBoxPlotComponent);

/**
 * Projects each row onto two columns, producing `{ x: row[x], y: row[y] }`
 * points. A nullish `data` yields `undefined`.
 */
const getPairData = (data: Record<string, number>[], x: string, y: string) =>
    data?.map(({ [x]: xValue, [y]: yValue }) => ({ x: xValue, y: yValue }));

/**
 * Summarises each cluster as a flat `{ [variable]: mean }` object, in the
 * iteration order of `data`. Missing or falsy means (including NaN) become 0.
 */
const transformDataForPairplot = (data: Record<number, DataDescription>) =>
    Object.values(data).map((description) =>
        Object.fromEntries(
            Object.entries(description).map(([variable, stats]) => [
                variable,
                stats?.mean || 0,
            ]),
        ),
    );

/**
 * Approximates a variable's distribution from its summary statistics.
 *
 * Places `count` samples at the midpoints of the four quartile ranges
 * (min–25%, 25%–50%, 50%–75%, 75%–max). Samples are split evenly across the
 * ranges, and any remainder goes to the lowest ranges first.
 *
 * Returns an empty array when `stats` is missing or its `count` is not a
 * finite number. `count` is floored and clamped at 0. Every sample is a
 * distinct object.
 *
 * @example
 * // count 6 -> two samples in each of the first two ranges, one in the others
 * generateMidpointArray({ count: 6, min: 0, "25%": 2, "50%": 4, "75%": 6, max: 8 })
 * // => [{ x: 1 }, { x: 1 }, { x: 3 }, { x: 3 }, { x: 5 }, { x: 7 }]
 */
function generateMidpointArray(stats: StatisticsObject): { x: number }[] {
    // Missing statistics, or a NaN/undefined count, previously crashed here
    // (destructuring null, or `Array(NaN)` throwing RangeError).
    if (!stats || !Number.isFinite(stats.count)) {
        return [];
    }
    const { count, min, "25%": q1, "50%": median, "75%": q3, max } = stats;

    const midpoints = [
        (min + q1) / 2,
        (q1 + median) / 2,
        (median + q3) / 2,
        (q3 + max) / 2,
    ];
    const total = Math.max(0, Math.floor(count));
    const numPerQuartile = Math.floor(total / 4);
    const remaining = total % 4;

    return midpoints.flatMap((midpoint, index) => {
        const extra = index < remaining ? 1 : 0;
        // Array.from creates a fresh object per sample; Array(n).fill(obj)
        // would repeat one shared reference.
        return Array.from({ length: numPerQuartile + extra }, () => ({
            x: midpoint,
        }));
    });
}

/**
 * Expands per-cluster descriptive statistics into histogram samples.
 *
 * Returns `{ [cluster]: { [variable]: samples } }`, where each sample list is
 * produced by `generateMidpointArray` from that variable's statistics.
 */
const transformDataForHistogram = (data: Record<number, DataDescription>) =>
    Object.fromEntries(
        Object.keys(data).map((cluster) => [
            cluster,
            Object.fromEntries(
                Object.keys(data[cluster]).map((variable) => [
                    variable,
                    generateMidpointArray(data[cluster][variable]),
                ]),
            ),
        ]),
    );

/**
 * Grid of pairwise charts over the per-cluster means.
 *
 * With N = `variables.length`, renders an N×N grid:
 * - Cell (i, j) with i ≠ j scatters every cluster's mean of `variables[j]` (x)
 *   against its mean of `variables[i]` (y), one coloured marker per cluster.
 * - Diagonal cells show a stacked histogram of that variable per cluster,
 *   built by `transformDataForHistogram`.
 *
 * Cells render as pulsing placeholders while there is no cluster data. Both
 * props are optional: without `variables` the grid is empty, and without
 * `clusterDescriptions` every cell is a placeholder.
 */
const PairplotComponent: React.FC<{
    variables?: string[];
    clusterDescriptions?: Record<number, DataDescription>;
}> = ({ variables = [], clusterDescriptions }) => {
    const { dependentAxisStyle, independentAxisStyle } = useChartStyles();

    // One `{ [variable]: mean }` object per cluster. `clusterDescriptions` is
    // optional and Object.keys(undefined) throws, so fall back to {}.
    const data = useMemo(() => {
        return transformDataForPairplot(clusterDescriptions ?? {});
    }, [clusterDescriptions]);

    // Largest cluster mean per variable (NaN becomes 0). Decides whether that
    // variable's axis ticks use the compact "k"/"m" format.
    const maxYmap = useMemo(() => {
        return data.length > 0
            ? Object.keys(data[0]).reduce(
                  (map, key) => ({
                      ...map,
                      [key]: Math.max(...data.map((point) => point[key])) || 0,
                  }),
                  {},
              )
            : {};
    }, [data]);

    // Per cluster, per variable: samples for the diagonal histograms.
    const histogramData = useMemo(() => {
        return transformDataForHistogram(clusterDescriptions ?? {});
    }, [clusterDescriptions]);

    // Tailwind only emits classes it finds as complete strings in the source,
    // so an interpolated `grid-cols-${n}` may produce no CSS. Set inline the
    // same column template that `grid-cols-N` generates.
    const gridStyle = {
        gridTemplateColumns: `repeat(${variables.length}, minmax(0, 1fr))`,
    };

    const hasData = data.length > 0;

    return (
        <div className="grid" style={gridStyle}>
            {variables.map((yVar, i) =>
                variables.map((xVar, j) => {
                    if (!hasData) {
                        return (
                            <div
                                key={`${yVar}-${xVar}`}
                                className="m-3 h-[250px] animate-pulse rounded-md bg-slate-600/10 dark:bg-slate-100/10"
                            ></div>
                        );
                    }
                    if (i !== j) {
                        return (
                            <div
                                key={`${yVar}-${xVar}`}
                                className="flex w-full flex-col items-center"
                            >
                                <VictoryChart
                                    theme={VictoryTheme.material}
                                    domainPadding={15}
                                    width={450}
                                >
                                    <VictoryAxis
                                        style={independentAxisStyle}
                                        label={xVar}
                                        tickFormat={
                                            maxYmap[xVar] > 10000
                                                ? tickFormat
                                                : null
                                        }
                                    />
                                    <VictoryAxis
                                        dependentAxis
                                        style={dependentAxisStyle}
                                        label={yVar}
                                        tickFormat={
                                            maxYmap[yVar] > 10000
                                                ? tickFormat
                                                : null
                                        }
                                    />
                                    <VictoryScatter
                                        data={getPairData(data, xVar, yVar)}
                                        style={{
                                            data: {
                                                fill: ({ index }) =>
                                                    centroidColors[
                                                        (index as number) %
                                                            centroidColors.length
                                                    ],
                                                fillOpacity: 0.5,
                                                stroke: ({ index }) =>
                                                    centroidColors[
                                                        (index as number) %
                                                            centroidColors.length
                                                    ],
                                                strokeWidth: 3,
                                            },
                                        }}
                                        size={9}
                                    />
                                </VictoryChart>
                            </div>
                        );
                    }
                    return (
                        <div
                            key={`histogram-${yVar}`}
                            className="flex w-full flex-col items-center"
                        >
                            <VictoryChart
                                theme={VictoryTheme.material}
                                domainPadding={15}
                                width={450}
                            >
                                <VictoryAxis
                                    style={independentAxisStyle}
                                    label={xVar}
                                    tickFormat={
                                        maxYmap[xVar] > 10000
                                            ? tickFormat
                                            : null
                                    }
                                />
                                <VictoryAxis
                                    dependentAxis
                                    style={dependentAxisStyle}
                                    label={"Count"}
                                />
                                <VictoryStack>
                                    {Object.keys(histogramData).map(
                                        (cluster, k) => (
                                            <VictoryHistogram
                                                key={`${k}-${xVar}`}
                                                data={
                                                    histogramData[cluster][
                                                        xVar
                                                    ] || []
                                                }
                                                style={{
                                                    data: {
                                                        // Cycle the palette like the scatter
                                                        // cells do, so clusters past the sixth
                                                        // still get a fill colour.
                                                        fill: centroidColors[
                                                            k %
                                                                centroidColors.length
                                                        ],
                                                        fillOpacity: 0.6,
                                                        strokeWidth: 0,
                                                        strokeOpacity: 0,
                                                    },
                                                }}
                                            />
                                        ),
                                    )}
                                </VictoryStack>
                            </VictoryChart>
                        </div>
                    );
                }),
            )}
        </div>
    );
};

export const Pairplot = React.memo(PairplotComponent);

/**
 * Line chart of an observed time series followed by its forecast.
 *
 * `originalValues` (observed) and `values` (forecast) each map a time horizon
 * (object key, parsed as a number) to a value.
 * - Observed points: a solid line with filled markers.
 * - Forecast points: a dotted line with markers filled in `contrastColor`.
 *
 * The x-axis spans from the first entry of `dataTimeHorizons` to the last
 * entry of `forecastTimeHorizons` when both exist. A pulsing placeholder is
 * shown while there are no forecast values.
 */
const ForecastChartComponent: React.FC<{
    values: Record<string, number>;
    originalValues: Record<string, number>;
    variable: string;
    dataTimeHorizons: number[];
    forecastTimeHorizons: number[];
}> = ({
    values,
    originalValues,
    variable,
    dataTimeHorizons,
    forecastTimeHorizons,
}) => {
    const { color, contrastColor, dependentAxisStyle, independentAxisStyle } =
        useChartStyles();

    // Forecast series as `{ x: horizon, y: value }` points.
    const lineData = useMemo(
        () =>
            Object.entries(values ?? {}).map(([timeHorizon, value]) => ({
                x: Number(timeHorizon),
                y: value,
            })),
        [values],
    );

    // Observed series, same shape as `lineData`.
    const originalData = useMemo(
        () =>
            Object.entries(originalValues ?? {}).map(([timeHorizon, value]) => ({
                x: Number(timeHorizon),
                y: value,
            })),
        [originalValues],
    );

    // Largest value across BOTH series, since both share the y-axis. Decides
    // whether y ticks use the compact "k"/"m" format.
    const maxY = useMemo(
        () =>
            [...lineData, ...originalData].reduce(
                (largest, point) => Math.max(largest, point.y),
                -Infinity,
            ),
        [lineData, originalData],
    );

    if (lineData.length === 0) {
        return (
            <div className="m-3 h-[350px] animate-pulse rounded-md bg-slate-600/10 dark:bg-slate-100/10"></div>
        );
    }

    // Pin the x-axis from the first observed horizon to the last forecast
    // horizon, but only when both exist: an empty array would otherwise give
    // Victory an `undefined` domain bound.
    const domainStart = dataTimeHorizons?.[0];
    const domainEnd = forecastTimeHorizons?.[forecastTimeHorizons.length - 1];
    const timeDomain: [number, number] | undefined =
        domainStart !== undefined && domainEnd !== undefined
            ? [domainStart, domainEnd]
            : undefined;

    return (
        <div className="-mt-10 flex w-full flex-col items-center px-4">
            <VictoryChart
                theme={VictoryTheme.material}
                width={450}
                domainPadding={{ y: 15 }}
            >
                <VictoryAxis
                    label={"Time"}
                    style={independentAxisStyle}
                    domain={timeDomain}
                    tickFormat={stringTickFormat}
                />
                <VictoryAxis
                    dependentAxis
                    label={variable}
                    style={dependentAxisStyle}
                    tickFormat={maxY > 10000 ? tickFormat : null}
                />
                <VictoryLine
                    data={originalData}
                    style={{
                        data: {
                            stroke: color,
                            strokeWidth: 2,
                        },
                    }}
                />
                <VictoryLine
                    data={lineData}
                    style={{
                        data: {
                            stroke: color,
                            strokeWidth: 2,
                            strokeDasharray: "1, 5",
                        },
                    }}
                />
                <VictoryScatter
                    style={{
                        data: {
                            fill: contrastColor,
                            stroke: color,
                            strokeWidth: 2,
                        },
                    }}
                    size={3}
                    data={lineData}
                />
                <VictoryScatter
                    style={{
                        data: {
                            fill: color,
                        },
                    }}
                    size={4}
                    data={originalData}
                />
            </VictoryChart>
        </div>
    );
};

export const ForecastChart = React.memo(ForecastChartComponent);
