import { Box, Button, Center, Flex, Loader, Paper, Stack, Text, Title, UnstyledButton, useMantineTheme } from "@mantine/core"
import { useDisclosure } from "@mantine/hooks"
import {
  DashboardTimespan,
  DashboardTimespanEnum,
  SafetyReport,
  SafetyReportFactorType,
  SafetyReportFactorTypeEnum,
} from "@soar/shared/types"
import { SafetyReportFactorConfig } from "@soar/shared/utils"
import dayjs, { Dayjs } from "dayjs"
import IsBetween from "dayjs/plugin/isBetween"
import dynamic from "next/dynamic"
import { useMemo } from "react"
import { match } from "ts-pattern"
import { CustomLink } from "../customLink"
import { ModalOrDrawer } from "../modalOrDrawer"
import { StatusDisplay } from "../statusDisplay"
import { SectionHeader } from "./SectionHeader"
import { calculateBreakpoints } from "./graphHelper"

// Registers the isBetween plugin so `Dayjs#isBetween` is available to the graph-data helpers in this file,
// which use it to assign reports to time windows.
dayjs.extend(IsBetween)

// nivo's responsive bar chart, loaded only on the client (`ssr: false`). A plain-text placeholder is shown
// while the chunk loads.
const ResponsiveBar = dynamic(() => import("@nivo/bar").then((mod) => mod.ResponsiveBar), {
  loading: () => <p>Loading...</p>,
  ssr: false,
})

/**
 * Tallies how often each safety factor appears across the reports created at or after `startDate`.
 *
 * @param safetyReports - reports to scan; each contributes one count per entry in its `factors`.
 * @param startDate - reports whose `createdAt` is before this instant are ignored.
 * @returns `{ factor, count }` pairs ordered from most to least frequent; ties keep the order in which
 *   the factors were first encountered.
 */
function countSafetyCategories(safetyReports: SafetyReport[], startDate: Dayjs) {
  const tally = new Map<SafetyReportFactorType, number>()

  const inRange = safetyReports.filter((report) => !dayjs(report.createdAt).isBefore(startDate))
  for (const report of inRange) {
    for (const factor of report.factors) {
      tally.set(factor, (tally.get(factor) ?? 0) + 1)
    }
  }

  const ranked = Array.from(tally, ([factor, count]) => ({ factor, count }))
  ranked.sort((left, right) => right.count - left.count)
  return ranked
}

/** One time window of a chart: its axis label, the breakpoint it starts at, and per-factor report counts. */
type FactorCountBucket = { label: string; day: Dayjs; counts: Map<SafetyReportFactorType, number> }

/**
 * Splits reports into the consecutive `[breakpoint, nextBreakpoint)` windows (second precision) produced by
 * `calculateBreakpoints`, and counts how often each safety factor appears in each window.
 *
 * For every timespan except `all` and `today`, a window's end is clamped to the current time so no window
 * extends into the future.
 *
 * @returns the `breakpoints` and `tickValues` from `calculateBreakpoints`, one bucket per consecutive pair
 *   of breakpoints (`graphData`), and the set of every factor seen in any bucket, in first-seen order.
 */
function bucketFactorCounts(data: SafetyReport[], option: DashboardTimespan, dateStart: Dayjs) {
  const { breakpoints, tickValues, xFormat } = calculateBreakpoints(option, dateStart)

  const now = dayjs()
  const clampToNow = option !== DashboardTimespanEnum.enum.all && option !== DashboardTimespanEnum.enum.today

  // Parse each timestamp once rather than once per window.
  const datedReports = data.map((report) => ({ report, createdAt: dayjs(report.createdAt) }))

  const relevantFactors = new Set<SafetyReportFactorType>()
  const graphData: FactorCountBucket[] = []

  breakpoints.forEach((breakpoint, index) => {
    let nextBreakpoint = breakpoints[index + 1]
    // The final breakpoint only closes the previous window; it does not open one of its own.
    if (nextBreakpoint == null) {
      return
    }
    if (clampToNow && nextBreakpoint.isAfter(now)) {
      nextBreakpoint = now
    }

    const counts = new Map<SafetyReportFactorType, number>()
    for (const { report, createdAt } of datedReports) {
      if (!createdAt.isBetween(breakpoint, nextBreakpoint, "second", "[)")) {
        continue
      }
      for (const factor of report.factors) {
        counts.set(factor, (counts.get(factor) ?? 0) + 1)
        relevantFactors.add(factor)
      }
    }

    graphData.push({ label: breakpoint.format(xFormat), day: breakpoint, counts })
  })

  return { breakpoints, tickValues, graphData, relevantFactors }
}

/**
 * Builds bar-chart data: one entry per time window, keyed by `id` (the window label), with one numeric
 * property per factor that occurred in that window.
 *
 * @returns `maxValue` is the largest per-window total across all factors. `tickValues` omits the last value
 *   returned by `calculateBreakpoints`. `relevantFactors` lists every factor that appears in any window.
 */
function calculateBarGraphData(data: SafetyReport[], option: DashboardTimespan, dateStart: Dayjs) {
  const { breakpoints, tickValues, graphData, relevantFactors } = bucketFactorCounts(data, option, dateStart)

  let maxValue = 0
  const series = graphData.map((timePeriod) => {
    const factorValues: Partial<Record<SafetyReportFactorType, number>> = {}
    let timePeriodCount = 0
    for (const factor of relevantFactors) {
      const count = timePeriod.counts.get(factor)
      if (count != null) {
        factorValues[factor] = count
        timePeriodCount += count
      }
    }
    maxValue = Math.max(maxValue, timePeriodCount)

    return {
      id: timePeriod.label,
      ...factorValues,
    }
  })

  return {
    breakpoints,
    graphData,
    maxValue,
    tickValues: tickValues.slice(0, tickValues.length - 1),
    relevantFactors: [...relevantFactors],
    series,
  }
}

/**
 * Builds line-chart data: one series per factor, with one `{ x: windowLabel, y: count }` point per time
 * window (zero when the factor did not occur in that window).
 *
 * @returns `maxValue` is a y-axis maximum two units above the highest single point, so the tallest point
 *   is never clipped.
 */
function calculateLineGraphData(data: SafetyReport[], option: DashboardTimespan, dateStart: Dayjs) {
  const { breakpoints, tickValues, graphData, relevantFactors } = bucketFactorCounts(data, option, dateStart)

  let maxValue = 0
  const series = [...relevantFactors].map((factor) => ({
    id: factor,
    data: graphData.map((bucket) => {
      const count = bucket.counts.get(factor) ?? 0
      maxValue = Math.max(maxValue, count)
      return {
        x: bucket.label,
        y: count,
      }
    }),
  }))

  // Headroom above the peak. This must never fall below `maxValue`, or the chart clips its highest point.
  // (Rounding down to a multiple of ten first, as before, turned a peak of 17 into an axis max of ~12.)
  const nextMaxValue = maxValue + 2

  return {
    breakpoints,
    graphData,
    maxValue: nextMaxValue,
    tickValues,
    relevantFactors: [...relevantFactors],
    series,
  }
}

/**
 * One row of a factor list: the factor's display label on the left and its count on the right.
 *
 * Falls back to the raw factor id when `SafetyReportFactorConfig` has no entry for it, as the chart tooltip
 * does. A factor value this client does not know about then renders instead of throwing.
 * List keys are supplied by the callers that map over factors, not by this component.
 */
function SafetyFactorCountRow({
  factor,
  count,
}: {
  factor: SafetyReportFactorType
  count: number
}) {
  const label = SafetyReportFactorConfig[factor]?.label ?? factor
  return (
    <Flex justify="space-between" w="100%">
      <Text>{label}</Text>
      <Text>{count}</Text>
    </Flex>
  )
}

// Stable fallback for an omitted `data` prop. A `data = []` parameter default would create a new array on
// every render, which changes the useMemo dependencies below and defeats their memoization.
const EMPTY_SAFETY_REPORTS: SafetyReport[] = []

/**
 * Dashboard section charting the most frequent safety-report factors for the selected timespan.
 *
 * Renders one of three states:
 * - a loader while `isLoading` is true;
 * - a bar chart (one bar per time window, one key per factor) when `data` is non-empty;
 * - an empty-state message otherwise.
 *
 * "View more" opens a modal listing every factor counted since `startDate`, most frequent first.
 *
 * @param data - safety reports to chart; defaults to none.
 * @param startDate - start of the timespan; the factor counts ignore reports created before it.
 * @param dateKey - timespan option used to compute the chart's time windows.
 * @param timespanLabel - text shown as the time period in the section header.
 * @param isLoading - when true, a loader is shown instead of the chart.
 */
export function TopSafetyReportFactorsSection({
  data = EMPTY_SAFETY_REPORTS,
  startDate,
  dateKey,
  timespanLabel,
  isLoading,
}: {
  dateKey: DashboardTimespan
  startDate: Dayjs
  data?: SafetyReport[]
  timespanLabel: string
  isLoading: boolean
}) {
  const [modalState, modalHandlers] = useDisclosure(false)

  const factorCounts = useMemo(() => countSafetyCategories(data, startDate), [data, startDate])

  const { series, maxValue, relevantFactors, tickValues } = useMemo(
    () => calculateBarGraphData(data, dateKey, startDate),
    [data, dateKey, startDate],
  )

  const hasData = data.length > 0
  // Only the first and last bottom-axis ticks get a text label; the rest render just the tick mark.
  const lastTickIndex = tickValues.length - 1

  return (
    <Flex w="100%" h="100%" direction="column" justify="space-between">
      <Box>
        <SectionHeader title="Most frequent risk categories" timePeriod={timespanLabel} />
        <Flex align="center" gap="xs" w="100%">
          {/* Top-six list rendered hidden (display="none"); presumably kept for an alternate layout — confirm. */}
          <Stack display="none" w="100%">
            {factorCounts.slice(0, 6).map(({ factor, count }) => (
              <SafetyFactorCountRow key={factor} factor={factor} count={count} />
            ))}
          </Stack>

          <Box h={325} w="100%">
            {match({ isLoading, hasData })
              .with({ isLoading: true }, () => (
                <Center py="xl" mt="xl">
                  <Loader variant="bars" mt="xl" />
                </Center>
              ))
              .with({ hasData: true }, () => (
                <ResponsiveBar
                  data={series}
                  enableGridY={false}
                  enableGridX={true}
                  keys={relevantFactors}
                  axisBottom={{
                    renderTick: (tick) => (
                      <g transform={`translate(${tick.x},${tick.y + 22})`}>
                        <rect x={-14} y={-6} rx={3} ry={3} width={28} height={24} fill="transparent" />
                        <rect x={-12} y={-12} rx={2} ry={2} width={24} height={24} fill="transparent" />
                        <line stroke="rgb(220,220,220)" strokeWidth={1} y1={-22} y2={-12} />
                        {(tick.tickIndex === 0 || tick.tickIndex === lastTickIndex) && (
                          <text
                            textAnchor="middle"
                            dominantBaseline="middle"
                            style={{
                              fontSize: 10,
                            }}
                          >
                            {tick.value}
                          </text>
                        )}
                      </g>
                    ),
                  }}
                  axisLeft={{
                    tickValues: [0, maxValue],
                  }}
                  tooltip={(tooltipProps) => {
                    // Fall back to the raw id for factors missing from the config.
                    const parsedFactorResult = SafetyReportFactorTypeEnum.safeParse(tooltipProps.id)
                    const factorLabel = parsedFactorResult.success
                      ? SafetyReportFactorConfig[parsedFactorResult.data].label
                      : tooltipProps.id
                    return (
                      <Paper key={tooltipProps.id} radius="sm" shadow="xs" p="sm">
                        <Text ta="center" fw={700} fz="lg" mb="sm">
                          {tooltipProps.indexValue}
                        </Text>
                        <Text>
                          <Text span fw={700} color={tooltipProps.color}>
                            {factorLabel}:
                          </Text>{" "}
                          {tooltipProps.value}
                        </Text>
                      </Paper>
                    )
                  }}
                  margin={{
                    top: 30,
                    right: 5,
                    bottom: 30,
                    left: 30,
                  }}
                />
              ))
              .otherwise(() => (
                <Box py={36}>
                  <StatusDisplay align="flex-start" label="No data for this time period" />
                </Box>
              ))}
          </Box>
        </Flex>
      </Box>
      <Box w="100%">
        <Flex justify="flex-end" pt="xs">
          {factorCounts.length > 0 && (
            <Button onClick={modalHandlers.open} variant="subtle" px={0}>
              View more
            </Button>
          )}
        </Flex>
      </Box>
      <ModalOrDrawer opened={modalState} onClose={modalHandlers.close} title={""}>
        <Title order={3} mb="md">
          Most frequent safety categories
        </Title>
        <Stack>
          {factorCounts.map(({ factor, count }) => (
            <SafetyFactorCountRow key={factor} factor={factor} count={count} />
          ))}
        </Stack>
      </ModalOrDrawer>
    </Flex>
  )
}
