import _ from "lodash";
import { PropsWithChildren, SetStateAction, startTransition, useLayoutEffect, useMemo, useState } from "react";
import { useParams } from "react-router-dom";
import Datepicker from "react-tailwindcss-datepicker";

import { getDateRange, getEndTimestamp, getStartTimestamp } from "@/adapters/monitoring";
import { AnalysisType, ColumnType, useFragment } from "@/apis/nannyml";
import { PersistentFilterProvider } from "@/components/Filters";
import { SortIcon } from "@/components/Icons";
import { RadioGroup, RadioGroupLabeledItem } from "@/components/RadioGroup";
import { Select, SelectItem, SelectVariantProps } from "@/components/Select";
import { Button } from "@/components/common/Button";
import { RangeSlider } from "@/components/common/RangeSlider/RangeSlider";
import { SelectCheckboxList } from "@/components/dashboard/SelectCheckbox/SelectCheckboxList";
import { usePlotConfig } from "@/components/monitoring/PlotConfig";
import { PerformancePlotMode, PlotType, SortOrder } from "@/constants/enums";
import { exceedsThreshold } from "@/domains/threshold";
import { performancePlotModeLabels, sortOrderLabels } from "@/formatters/filters";
import { metricLabels } from "@/formatters/monitoring";
import { useModelSchema } from "@/hooks/models";
import { calculateResultThreshold } from "@/hooks/monitoring";
import { formatDate } from "@/lib/dateUtils";
import { cn } from "@/lib/utils";

import { DateFilterContext } from "./DateFilter.context";
import { useDateFilterContext } from "./DateFilter.hook";
import {
  useResultFilter,
  useResultFilterConfig,
  useResultGrouper,
  useResultFilterContext,
  useResultLabeler,
  ResultFilterFragmentType,
  ResultFilterType,
  resultFilterFragment,
} from "./ResultFilters.hooks";

/**
 * Wraps `children` in the persistent result filter store and the date filter context.
 *
 * Every data point of every result is annotated with `hasAlert`, computed against that result's threshold, so that
 * filters, sorters and labelers further down can inspect alerts without recomputing thresholds.
 *
 * @param filterStoreName - Suffix of the persistent store name (`ResultFilterStore.<filterStoreName>`).
 * @param results - Result fragments to make filterable.
 */
export const ResultFilterContextProvider = ({
  children,
  filterStoreName,
  results: resultFragments,
}: PropsWithChildren<{
  filterStoreName: string;
  results: ResultFilterFragmentType[];
}>) => {
  const fragmentResults = useFragment(resultFilterFragment, resultFragments);

  // Memoized so PersistentFilterProvider receives a stable `items` reference across re-renders. Rebuilding the array on
  // every render would presumably also give derived values downstream (such as the available date range) a new
  // identity each time, which in turn resets the active date range in DateFilterContextProvider.
  const results = useMemo(
    () =>
      fragmentResults.map((result) => {
        const threshold = calculateResultThreshold(result);
        return {
          ...result,
          data: result.data.map((dp) => ({ ...dp, hasAlert: exceedsThreshold(dp.value, threshold) })),
        };
      }),
    [fragmentResults]
  );

  // Stable default config; by default no metric/tag filters are applied and groups are sorted by alert recency
  const defaultFilterConfig = useMemo(
    () => ({
      metrics: [],
      tags: [],
      sortOrder: SortOrder.RecencyOfAlerts,
    }),
    []
  );

  return (
    <PersistentFilterProvider
      storeName={`ResultFilterStore.${filterStoreName}`}
      items={results}
      defaultFilterConfig={defaultFilterConfig}
    >
      <DateFilterContextProvider>{children}</DateFilterContextProvider>
    </PersistentFilterProvider>
  );
};

/**
 * Provides the available and the currently selected ("active") date range of the filtered results.
 *
 * The available range spans from the earliest start timestamp to the latest end timestamp over all data points of the
 * filtered results, for the current plot type. The active range defaults to the available range and is reset to it
 * whenever the available range changes. When there are no filtered results, the available range is the sentinel
 * `[Number.MAX_VALUE, -Number.MAX_VALUE]`.
 */
export const DateFilterContextProvider = ({ children }: PropsWithChildren<{}>) => {
  const { type: plotType } = usePlotConfig();
  const { filteredItems } = useResultFilterContext();

  // Determine min/max date range from filtered results
  const dateRange = useMemo(
    () =>
      filteredItems.flat().reduce(
        ([minDate, maxDate], result): [number, number] => {
          const [startTimestamp, endTimestamp] = getDateRange(result.data, plotType);
          return [Math.min(minDate, startTimestamp), Math.max(maxDate, endTimestamp)];
        },
        [Number.MAX_VALUE, -Number.MAX_VALUE] as [number, number]
      ),
    [filteredItems, plotType]
  );

  // State to store currently selected date range, defaulting to full range of results
  const [activeDateRange, setActiveDateRange] = useState<[number, number]>(dateRange);

  // Update active date range when available date range changes
  useLayoutEffect(() => {
    setActiveDateRange(dateRange);
  }, [dateRange]);

  // Memoized so the context value (and the callbacks on it) only changes identity when one of the ranges changes,
  // instead of on every render of this provider
  const contextValue = useMemo(
    () => ({
      activeDateRange,
      dateRange,
      // Selection changes are wrapped in a transition, marking the resulting re-render as non-urgent
      setActiveDateRange: (nextRange: SetStateAction<[number, number]>) =>
        startTransition(() => setActiveDateRange(nextRange)),
      resetActiveDateRange: () => setActiveDateRange(dateRange),
    }),
    [activeDateRange, dateRange]
  );

  return <DateFilterContext.Provider value={contextValue}>{children}</DateFilterContext.Provider>;
};

/**
 * Column filter: restricts the result groups to those referencing at least one selected column.
 *
 * Columns that occur in the results are offered grouped by their schema column type (targets, model output,
 * categorical and continuous features). Selected columns that are missing from the model schema are ignored.
 */
export const FilterColumns = () => {
  // TODO: This is a hack to get the model ID from the URL. We should probably get this from the context instead
  const { modelId } = useParams();
  const schema = useModelSchema(parseInt(modelId!));

  const [{ columns: selectedColumns }, setFilterConfig] = useResultFilterConfig();
  const { items: results } = useResultFilterContext();

  useResultFilter(
    (resultGroups, { columns: configuredColumns }) => {
      // Drop selected columns the schema does not know about. This may happen because of client side caching when the
      // server instance is wiped.
      const knownColumns = configuredColumns?.filter((columnName) => columnName in schema.columns);
      if (!knownColumns || knownColumns.length === 0) {
        return resultGroups;
      }

      const isSelected = (columnName: string) => knownColumns.includes(columnName);
      return resultGroups.filter((group) =>
        group.some(
          (result) =>
            (result.columnName && isSelected(result.columnName)) ||
            (result.columnNames && result.columnNames.some(isSelected))
        )
      );
    },
    [schema.columns]
  );

  const setSelectedColumns = (updateColumns: (currentColumns: string[]) => string[]) => {
    setFilterConfig({ columns: updateColumns(selectedColumns ?? []) });
  };

  type Labels = { [column: string]: string };
  const columnGroup = (title: string, ...types: ColumnType[]) => ({ types, title, labels: {} as Labels });
  const groupedColumns = [
    columnGroup("Targets", ColumnType.Target),
    columnGroup("Model output", ColumnType.Prediction, ColumnType.PredictionScore),
    columnGroup("Categorical features", ColumnType.CategoricalFeature),
    columnGroup("Continuous features", ColumnType.ContinuousFeature),
  ];

  // Every distinct column referenced by a result, single-column or multi-column
  const availableColumns = _.uniq(results.flatMap((result) => result.columnName ?? result.columnNames ?? []));
  for (const column of availableColumns) {
    const columnType = schema.columns[column]?.columnType;
    const matchingGroup = groupedColumns.find(({ types }) => types.includes(columnType));
    if (matchingGroup) {
      matchingGroup.labels[column] = column;
    }
  }

  return (
    <SelectCheckboxList
      groupClassName="max-h-80 overflow-y-auto"
      values={availableColumns}
      names={groupedColumns}
      selectedValues={selectedColumns}
      setSelectedValues={setSelectedColumns}
      showSearch={true}
      showSelectOnlyCurrent={true}
      searchClassName="border-0"
      emptyPlaceholder="No columns found."
    />
  );
};

/**
 * Radio group selecting how realized and estimated performance results are shown.
 *
 * Registers a grouper that, in the "compare realized to estimated" mode, pairs results sharing segment, metric and
 * component (realized first). It also registers a filter that keeps only groups containing a realized or an estimated
 * result in the respective single-type modes. Other modes pass result groups through unchanged.
 */
export const FilterPerformancePlotMode = () => {
  const [{ plotMode: selectedPlotMode }, setFilterConfig] = useResultFilterConfig();

  useResultGrouper((results, { plotMode }) => {
    // Only the comparison mode produces groups
    if (plotMode !== PerformancePlotMode.CompareRealizedToEstimatedPerformance) {
      return [];
    }

    // A Map keeps insertion order, so groups follow the order in which their first result appears
    const groupsByKey = new Map<string, ResultFilterType[]>();
    for (const result of results) {
      // Component name may be null, e.g. for DLE. Falling back to the metric name lets such results group with
      // calculators that do report components
      const segmentPrefix = result.segment ? `${result.segment.segmentColumnName}:${result.segment.segment}.` : "";
      const groupKey = `${segmentPrefix}${result.metricName}.${result.componentName ?? result.metricName}`;
      const group = groupsByKey.get(groupKey);
      if (group) {
        group.push(result);
      } else {
        groupsByKey.set(groupKey, [result]);
      }
    }

    // Descending analysis type puts realized first, since 'estimated' precedes 'realized' alphabetically
    return Array.from(groupsByKey.values(), (group) => _.orderBy(group, (result) => result.analysisType, "desc"));
  }, []);

  useResultFilter((resultGroups, { plotMode }) => {
    const requiredAnalysisType =
      plotMode === PerformancePlotMode.EstimatedPerformance
        ? AnalysisType.EstimatedPerformance
        : plotMode === PerformancePlotMode.RealizedPerformance
          ? AnalysisType.RealizedPerformance
          : undefined;

    // Modes that do not single out one analysis type keep every group
    if (requiredAnalysisType === undefined) {
      return resultGroups;
    }
    return resultGroups.filter((group) => group.some((result) => result.analysisType === requiredAnalysisType));
  });

  return (
    <RadioGroup
      className="flex flex-col"
      value={selectedPlotMode ?? PerformancePlotMode.RealizedAndEstimatedPerformance}
      onValueChange={(nextMode) => setFilterConfig({ plotMode: nextMode as PerformancePlotMode })}
    >
      {Object.entries(PerformancePlotMode).map(([optionKey, optionValue]) => (
        <RadioGroupLabeledItem key={optionKey} value={optionValue} label={performancePlotModeLabels[optionValue]} />
      ))}
    </RadioGroup>
  );
};

/** Computes the value a group of results is ordered (and labelled) by. */
type SortValueFn = (results: readonly ResultFilterType[]) => number | string;

/** Smallest metric name in the group (per `_.min`), or "" when there is none. */
const smallestMetricName: SortValueFn = (results) => _.min(results.map(({ metricName }) => metricName)) ?? "";

/** Sort value per sort order. Method and metric ordering currently both use the metric name. */
const sorters: Record<SortOrder, SortValueFn> = {
  [SortOrder.Method]: smallestMetricName,
  [SortOrder.Metric]: smallestMetricName,
  // Multi-column results are represented by their comma-joined column names
  [SortOrder.Column]: (results) =>
    _.min(results.map(({ columnName, columnNames }) => columnName ?? columnNames?.join(", "))) ?? "",
  // Total number of alerting data points across the group
  [SortOrder.NrOfAlers]: (results) => _.sumBy(results, ({ data }) => data.filter(({ hasAlert }) => hasAlert).length),
  // Maximum over the group of each result's last alerting data point start timestamp; "" when nothing alerts
  [SortOrder.RecencyOfAlerts]: (results) => {
    const lastAlertStarts = results.map(({ data }) => data.findLast(({ hasAlert }) => hasAlert)?.startTimestamp ?? "");
    return _.max(lastAlertStarts) ?? "";
  },
};

/**
 * Sort direction applied when a sort order is chosen: textual orders ascend, alert-based orders descend so that
 * groups with the most / the latest alerts come first.
 */
const defaultSortDirection: Record<SortOrder, "asc" | "desc"> = {
  // Textual orderings
  [SortOrder.Column]: "asc",
  [SortOrder.Metric]: "asc",
  [SortOrder.Method]: "asc",
  // Alert-based orderings
  [SortOrder.RecencyOfAlerts]: "desc",
  [SortOrder.NrOfAlers]: "desc",
};

/** Describes an alert count: "No alerts" for 0, "N Alerts" above 1, otherwise "1 Alert". */
const describeAlertCount = (count: number | string) => {
  if (count === 0) {
    return "No alerts";
  }
  return (count as number) > 1 ? `${count} Alerts` : "1 Alert";
};

/** Display label of a metric for the given analysis type, falling back to the raw value when no label is found. */
const describeMetric = (metric: any, analysisType: AnalysisType) => metricLabels[analysisType](metric) ?? metric;

/** Turns a group's sort value into the label shown for that group, per sort order. */
const sortLabelers: Record<SortOrder, (value: number | string, analysisType: AnalysisType) => string> = {
  [SortOrder.NrOfAlers]: describeAlertCount,
  [SortOrder.RecencyOfAlerts]: (timestamp) => (timestamp ? formatDate(timestamp) : "No alerts"),
  [SortOrder.Metric]: describeMetric,
  [SortOrder.Method]: describeMetric,
  [SortOrder.Column]: (columns) => columns as string,
};

/**
 * Sort controls for result groups: a select for the sort order and a button toggling the sort direction.
 *
 * Registers a filter ordering result groups by the selected sort order, and a labeler describing each group by its sort
 * value (e.g. its number of alerts).
 *
 * When the filter config has no explicit sort direction (the default config only sets a sort order), the sort order's
 * default direction is used for sorting, for the icon and as the starting point for the toggle. Previously all three
 * fell back to "asc", so e.g. the default "recency of alerts" order sorted ascending, and the first toggle click was a
 * no-op.
 *
 * @param size - Size variant forwarded to the select.
 * @param sortOptions - Sort orders offered in the select, in display order.
 */
export const FilterSortOrder = ({
  size,
  sortOptions,
}: SelectVariantProps & {
  sortOptions: SortOrder[];
}) => {
  const [{ sortDirection, sortOrder }, setFilterConfig] = useResultFilterConfig();
  // Direction actually in effect; "asc" is only a last resort for an unknown (e.g. stale persisted) sort order
  const effectiveSortDirection = sortDirection ?? defaultSortDirection[sortOrder] ?? "asc";

  useResultFilter((resultGroups, { sortDirection, sortOrder }) => {
    const valueFn = sorters[sortOrder];
    return _.orderBy(resultGroups, valueFn, sortDirection ?? defaultSortDirection[sortOrder] ?? "asc");
  });

  useResultLabeler((results, { sortOrder }) => {
    const sortValue = sorters[sortOrder](results);
    // NOTE(review): assumes result groups are never empty — results[0] would be undefined otherwise
    return sortLabelers[sortOrder](sortValue, results[0].analysisType);
  });

  return (
    <div className="flex gap-2">
      <Select
        className="w-[165px]"
        value={sortOrder}
        onValueChange={(value) =>
          setFilterConfig({
            sortOrder: value as SortOrder,
            sortDirection: defaultSortDirection[value as SortOrder],
          })
        }
        size={size}
      >
        {sortOptions.map((option) => (
          <SelectItem key={option} value={option}>
            {sortOrderLabels[option]}
          </SelectItem>
        ))}
      </Select>
      <Button
        cva={{ size: "small", border: "thin" }}
        onClick={() => setFilterConfig({ sortDirection: effectiveSortDirection === "asc" ? "desc" : "asc" })}
      >
        <SortIcon
          direction={effectiveSortDirection}
          type={sortOrder === SortOrder.NrOfAlers ? "numeric" : "alphabetic"}
        />
      </Button>
    </div>
  );
};

/**
 * Range slider over the distinct data point timestamps of the filtered results.
 *
 * Slider positions index into the ascending lists of distinct start and end timestamps (for the current plot type).
 * Moving the slider updates the active date range, and the active range is mapped back to the closest slider
 * positions. The mark shows the slider position where the first result's analysis period starts.
 */
export const FilterDateRangeSlider = ({ className }: { className?: string }) => {
  const { filteredItems: filteredResultGroups } = useResultFilterContext();
  const {
    activeDateRange: [minDate, maxDate],
    setActiveDateRange,
  } = useDateFilterContext();
  const { type: plotType } = usePlotConfig();

  const [startIndices, endIndices, analysisStartIdx] = useMemo((): [number[], number[], number] => {
    const startTimestamps = new Set<number>();
    const endTimestamps = new Set<number>();
    filteredResultGroups.forEach((results) =>
      results.forEach((result) =>
        result.data.forEach((dp) => {
          startTimestamps.add(getStartTimestamp(dp, plotType));
          endTimestamps.add(getEndTimestamp(dp, plotType));
        })
      )
    );

    // Sort numerically: without a comparator Array.prototype.sort compares the string forms of the numbers, which
    // misorders timestamps with a different number of digits
    const sortedStarts = Array.from(startTimestamps).sort((a, b) => a - b);
    const sortedEnds = Array.from(endTimestamps).sort((a, b) => a - b);

    // Express the analysis start as a slider position by looking up the start timestamp of the first analysis data
    // point. Its index within a single result's data may differ from the slider position when results have different
    // time bases.
    // FIXME: Only the first result is considered; results with different analysis periods may disagree
    const firstAnalysisPoint = filteredResultGroups
      .at(0)
      ?.at(0)
      ?.data.find((dp) => dp.isAnalysis);
    const analysisStart = firstAnalysisPoint
      ? sortedStarts.indexOf(getStartTimestamp(firstAnalysisPoint, plotType))
      : -1;

    return [sortedStarts, sortedEnds, analysisStart];
  }, [filteredResultGroups, plotType]);

  // Memoized so the debounce timer survives re-renders; creating a new debounced function on every render means calls
  // made across renders are not coalesced
  const handleSliderValueChange = useMemo(
    () =>
      _.debounce(
        (values: number[]) => setActiveDateRange([startIndices[values[0]], endIndices[values[1]]]),
        // Distribution plots are more expensive to render, so we debounce the slider value change to improve UX
        plotType === PlotType.Distribution ? 250 : 0
      ),
    [startIndices, endIndices, setActiveDateRange, plotType]
  );

  // Find indices closest to selected min/max date
  const startIndex = selectIndexOfClosestValue(startIndices, minDate);
  const endIndex = selectIndexOfClosestValue(endIndices, maxDate);

  return (
    <RangeSlider
      className={className}
      onChange={handleSliderValueChange}
      minIndex={0}
      maxIndex={startIndices.length - 1}
      markIndex={analysisStartIdx}
      startValue={startIndex}
      endValue={endIndex}
      startIndices={startIndices}
      endIndices={endIndices}
      labelClassName="dark:text-slate-400"
    />
  );
};

/**
 * Date picker for the active date range, bounded by the full available date range.
 *
 * Only complete selections (both a start and an end date) update the active range; the selected dates are converted to
 * millisecond timestamps via `Date`.
 */
export const FilterDateRangePicker = ({
  className,
  inputClassName,
}: {
  className?: string;
  inputClassName?: string;
}) => {
  const {
    dateRange: [firstAvailable, lastAvailable],
    activeDateRange: [activeStart, activeEnd],
    setActiveDateRange,
  } = useDateFilterContext();

  return (
    <Datepicker
      value={{ startDate: new Date(activeStart), endDate: new Date(activeEnd) }}
      useRange={false}
      startFrom={new Date(firstAvailable)}
      minDate={new Date(firstAvailable)}
      maxDate={new Date(lastAvailable)}
      onChange={(selection) => {
        // Ignore cleared or partial selections
        if (!selection?.startDate || !selection.endDate) {
          return;
        }
        setActiveDateRange([new Date(selection.startDate).getTime(), new Date(selection.endDate).getTime()]);
      }}
      containerClassName={(defaults) => cn(defaults, "border rounded-md", className)}
      inputClassName={(defaults) =>
        cn(
          defaults,
          "min-w-[195px] dark:bg-dark rounded-md p-2 outline-0 ring-inset focus:ring-1 dark:ring-white",
          "[&:not(:focus)]:hover:dark:bg-slate-900 [&:not(:focus)]:hover:cursor-pointer",
          inputClassName
        )
      }
      toggleClassName="hidden"
    />
  );
};

/**
 * Returns the index of the element of `values` with the smallest absolute distance to `value`.
 *
 * Ties resolve to the earliest index. Returns -1 for an empty array.
 */
const selectIndexOfClosestValue = (values: number[], value: number) => {
  if (values.length === 0) {
    return -1;
  }

  let closestIndex = 0;
  let closestDistance = Math.abs(values[0] - value);
  for (let i = 1; i < values.length; i++) {
    const distance = Math.abs(values[i] - value);
    // Strict comparison keeps the earliest index on ties
    if (distance < closestDistance) {
      closestIndex = i;
      closestDistance = distance;
    }
  }
  return closestIndex;
};
