import { differenceBy, intersectionBy, uniqBy } from 'lodash-es';
import { useCallback, useState } from 'react';
import { AgGridReactRef } from '@/lib/ag-grid/types';

/**
 * Returns true when `rows` is non-empty and every row in it (matched by `id`)
 * is present in `selected`. An empty `rows` list is never considered selected.
 */
function containsAllRowsById<T extends { id: string | number }>(
  selected: T[],
  rows: T[],
): boolean {
  return (
    rows.length > 0 &&
    intersectionBy(selected, rows, 'id').length === rows.length
  );
}

/**
 * Tracks which grid rows are selected. Supports single-row toggling,
 * shift-click range selection, a "select all shown rows" toggle and
 * group-level toggling. Rows are identified by their `id`.
 *
 * @param _ Grid ref; currently not read by this hook.
 * @param shownRows Rows currently displayed, in display order. Range selection
 *   and the "all shown rows" toggle operate on this list.
 * @returns Selection state (`selectedRows`, `firstSelectedRow`,
 *   `allShownRowsChecked`) and the handlers that modify it.
 */
export default function useGridSelectedRows<T extends { id: string | number }>(
  _: React.MutableRefObject<AgGridReactRef | undefined>,
  shownRows: T[],
) {
  const [selectedRows, setSelectedRows] = useState<T[]>([]);
  // Anchor for shift-click range selection: the row that was checked while the
  // selection was empty. Cleared by the "all shown rows" toggle and deselectAll.
  const [firstSelectedRow, setFirstSelectedRow] = useState<T | null>(null);
  const hasSelection = selectedRows.length > 0;

  /**
   * Toggles `row` in the selection. With `withShift`, replaces the selection
   * with the contiguous range of `shownRows` between the anchor and `row`
   * (inclusive), provided both are currently shown.
   */
  const handleRowCheck = useCallback(
    (row: T, withShift?: boolean) => {
      // When nothing is selected, `row` becomes the new anchor. Use it right
      // away instead of the previous (possibly stale) anchor: the state update
      // below is not visible to this closure.
      const anchor = hasSelection ? firstSelectedRow : row;
      if (!hasSelection) {
        setFirstSelectedRow(row);
      }
      setSelectedRows((prev) => {
        if (withShift && anchor) {
          const anchorIndex = shownRows.findIndex(
            ({ id }) => id === anchor.id,
          );
          const rowIndex = shownRows.findIndex(({ id }) => id === row.id);
          // Only build a range when both ends are shown; a -1 index would make
          // `slice` count from the end of the array.
          if (anchorIndex !== -1 && rowIndex !== -1) {
            return shownRows.slice(
              Math.min(anchorIndex, rowIndex),
              Math.max(anchorIndex, rowIndex) + 1,
            );
          }
        }
        return prev.some(({ id }) => id === row.id)
          ? prev.filter(({ id }) => id !== row.id)
          : [...prev, row];
      });
    },
    [hasSelection, firstSelectedRow, shownRows],
  );

  const allShownRowsChecked = containsAllRowsById(selectedRows, shownRows);

  /**
   * Selects every shown row, or, if all of them are already selected,
   * deselects them. Rows outside `shownRows` keep their state.
   */
  const handleAllShownRowsCheck = useCallback(() => {
    setSelectedRows((prev) =>
      // Decide from `prev` rather than the render-time `allShownRowsChecked`
      // so queued updates are taken into account.
      containsAllRowsById(prev, shownRows)
        ? differenceBy(prev, shownRows, 'id')
        : uniqBy([...prev, ...shownRows], 'id'),
    );
    setFirstSelectedRow(null);
  }, [shownRows]);

  /** Clears the selection and the shift-click anchor. */
  const deselectAll = useCallback(() => {
    setFirstSelectedRow(null);
    setSelectedRows([]);
  }, []);

  /**
   * Selects all `groupChildren`, or deselects them if every one is already
   * selected. An empty group is a no-op.
   */
  const handleGroupCheck = useCallback((groupChildren: T[]) => {
    if (groupChildren.length === 0) {
      return;
    }
    setSelectedRows((prev) =>
      containsAllRowsById(prev, groupChildren)
        ? differenceBy(prev, groupChildren, 'id')
        : uniqBy([...prev, ...groupChildren], 'id'),
    );
  }, []);

  /** Whether every child of a group is selected (false for an empty group). */
  const resolveGroupCheck = useCallback(
    (groupChildren: T[]) => containsAllRowsById(selectedRows, groupChildren),
    [selectedRows],
  );

  return {
    selectedRows,
    deselectAll,
    allShownRowsChecked,
    firstSelectedRow,
    handleAllShownRowsCheck,
    handleRowCheck,
    handleGroupCheck,
    resolveGroupCheck,
  };
}
