import {BarStack} from '@visx/shape';
import {Group} from '@visx/group';
import {Grid} from '@visx/grid';
import {LinearGradient} from '@visx/gradient';
import {AxisBottom, AxisLeft} from '@visx/axis';
import {scaleBand, scaleLinear, scaleOrdinal} from '@visx/scale';
import {useTooltip, useTooltipInPortal, defaultStyles} from '@visx/tooltip';
import {LegendOrdinal} from '@visx/legend';
import {localPoint} from '@visx/event';
import {ParentSize} from '@visx/responsive';
import {useEffect, useId, useRef} from 'react';
import {VIZ_PALETTES} from '../palettes';
import style from './StackedBarChart.module.css';
import type {SeriesPoint} from '@visx/shape/lib/types';

/** Key type allowed for both series names and the category field of a row. */
type DataKey = string | number;
/** One input row: each `Key` property holds either a number or a string label. */
type Datum<Key extends DataKey> = Record<Key, number | string>;

/**
 * Tooltip payload for a hovered stacked-bar segment. The chart stores the
 * per-bar object from the visx `BarStack` render callback as this value.
 */
interface TooltipData<Key extends DataKey, Data extends Datum<Key>> {
	/** Series key of the hovered segment. */
	key: Key;
	/** Stacked series point; `bar.data` is the original input row. */
	bar: SeriesPoint<Data>;
	/** Index of the segment as reported by visx `BarStack`. */
	index: number;
	/** Segment geometry, in the coordinate space of the plot `<Group>`. */
	x: number;
	y: number;
	width: number;
	height: number;
	/** Fill colour used for the segment. */
	color: string;
}

/** Props accepted by the stacked bar chart component. */
interface Props<Key extends DataKey, Data extends Datum<Key>> {
	/** Rows to plot: a category field plus one value per entry in `series`. */
	data: Data[];
	/** Keys of each row that become stacked segments, in stacking order. */
	series: Key[];
	/** Total SVG width in pixels. */
	width: number;
	/** Total SVG height in pixels. */
	height: number;
	/** Colour palette and background gradient to use; defaults to `'halp'`. */
	palette?: 'sunset' | 'halp';
	/** Space (px) between the SVG edge and the plotting area. */
	margin?: {top: number; right: number; bottom: number; left: number};
	/** When true, clicking a segment shows an `alert` with its data; defaults to `false`. */
	events?: boolean;
}

/**
 * Fallback spacing (px) around the plotting area. The larger `top` value
 * leaves room for the legend drawn above the chart.
 */
const defaultMargin = {
	top: 60,
	right: 30,
	bottom: 50,
	left: 50,
};

/**
 * Renders an SVG stacked bar chart with a gradient background, grid, axes,
 * a colour legend and a hover tooltip.
 *
 * The category (x-axis) field is the first key of `data[0]` that is not listed
 * in `series`. Each key in `series` becomes one stacked segment, and its values
 * are converted with `Number()`. Renders nothing when `width < 10`, when `data`
 * is empty, or when no category field can be found.
 *
 * @param data - Rows to plot: the category field plus one value per series key.
 * @param series - Keys of each row to stack, in stacking order.
 * @param width - Total SVG width in pixels.
 * @param height - Total SVG height in pixels.
 * @param palette - `VIZ_PALETTES` entry (and matching background gradient) to use.
 * @param events - When true, clicking a segment shows an `alert` with its data.
 * @param margin - Padding between the SVG edge and the plotting area.
 */
function StackedBarChart<Key extends DataKey, Data extends Datum<Key>>({
	data,
	series,
	width,
	height,
	palette = 'halp',
	events = false,
	margin = defaultMargin,
}: Props<Key, Data>) {
	const id = useId();
	const {
		tooltipOpen,
		tooltipLeft,
		tooltipTop,
		tooltipData,
		hideTooltip,
		showTooltip,
	} = useTooltip<TooltipData<Key, Data>>();

	const {containerRef, TooltipInPortal} = useTooltipInPortal({
		// TooltipInPortal is rendered in a separate child of <body /> and positioned
		// with page coordinates which should be updated on scroll. consider using
		// Tooltip or TooltipWithBounds if you don't need to render inside a Portal
		scroll: true,
	});

	// Pending "hide tooltip" timer, held per component instance. A module-level
	// variable would be shared by every chart on the page, so hovering one chart
	// could cancel another chart's pending hide and leave its tooltip stuck open.
	const tooltipTimeoutRef = useRef<number | undefined>(undefined);

	// On unmount, cancel any pending hide so it never fires for a dead component.
	// Reading the ref inside the cleanup is intentional: we want the latest timer id.
	useEffect(
		() => () => {
			window.clearTimeout(tooltipTimeoutRef.current);
		},
		[],
	);

	// All hooks above run unconditionally; only bail out after them.
	if (width < 10 || data.length === 0) return null;

	// bounds of the plotting area (inside the margins)
	const xMax = width - margin.left - margin.right;
	const yMax = height - margin.top - margin.bottom;

	// The category field is the first row key that is not a series. Compare as
	// strings: Object.keys always yields strings, while `series` may hold numbers.
	const seriesNames = new Set(series.map(String));
	const categoryKey = Object.keys(data[0]).find((key) => !seriesNames.has(key));
	// Without a category field every bar would get an undefined x position.
	if (categoryKey === undefined) return null;
	const xKey = categoryKey as Key;

	const xScale = scaleBand<string | number>({
		domain: data.map((d) => d[xKey]),
	}).rangeRound([0, xMax]);

	// Palettes are keyed by colour count; the series count is clamped to 3..7
	// (presumably the sizes VIZ_PALETTES defines — confirm against ../palettes).
	const colorIndex = `${Math.max(
		Math.min(series.length, 7),
		3,
	)}` as keyof (typeof VIZ_PALETTES)[typeof palette];

	const colorScale = scaleOrdinal<Key, string>({
		domain: series,
		range: VIZ_PALETTES[palette][colorIndex],
	});

	// y domain runs from 0 to the tallest stack (sum of all series in a row).
	const yScale = scaleLinear<number>({
		domain: [
			0,
			Math.max(
				...data.map((d) => {
					return series.reduce((sum, key) => {
						return sum + Number(d[key]);
					}, 0);
				}),
			),
		],
		nice: true,
	}).range([yMax, 0]);

	const xGet = (d: Data) => d[xKey];

	const backgroundGradient =
		palette === 'sunset'
			? {
					to: 'var(--color-orange-3)',
					from: 'var(--color-red-2)',
				}
			: {
					to: 'var(--color-purple-3)',
					from: 'var(--color-violet-2)',
				};

	return (
		<div style={{position: 'relative'}}>
			<svg ref={containerRef} width={width} height={height}>
				<LinearGradient id={`stackedbarchart-${id}`} {...backgroundGradient} />
				<rect
					x={0}
					y={0}
					width={width}
					height={height}
					fill={`url('#stackedbarchart-${id}')`}
					rx={14}
				/>
				<Group top={margin.top} left={margin.left}>
					<Grid
						xScale={xScale}
						yScale={yScale}
						width={xMax}
						height={yMax}
						stroke="white"
						strokeOpacity={0.1}
						xOffset={xScale.bandwidth() / 2}
					/>
					<BarStack<Data, Key>
						data={data}
						keys={series}
						x={xGet}
						xScale={xScale}
						yScale={yScale}
						color={colorScale}
					>
						{(barStacks) =>
							barStacks.map((barStack) =>
								barStack.bars.map((bar) => (
									<rect
										key={`bar-stack-${barStack.index}-${bar.index}`}
										x={bar.x}
										y={bar.y}
										height={bar.height}
										width={bar.width}
										fill={bar.color}
										onClick={() => {
											if (events) alert(`clicked: ${JSON.stringify(bar)}`);
										}}
										onMouseLeave={() => {
											tooltipTimeoutRef.current = window.setTimeout(() => {
												hideTooltip();
											}, 300);
										}}
										onMouseMove={(event) => {
											if (tooltipTimeoutRef.current !== undefined) {
												window.clearTimeout(tooltipTimeoutRef.current);
												tooltipTimeoutRef.current = undefined;
											}
											// TooltipInPortal expects coordinates relative to containerRef
											// (the <svg>). localPoint returns the pointer position in the
											// <svg>'s coordinate system. bar.x, however, is relative to the
											// margin-shifted <Group>, so margin.left is added back to
											// centre the tooltip horizontally on the bar.
											const eventSvgCoords = localPoint(event);
											const left = margin.left + bar.x + bar.width / 2;
											showTooltip({
												tooltipData: bar,
												tooltipTop: eventSvgCoords?.y,
												tooltipLeft: left,
											});
										}}
									/>
								)),
							)
						}
					</BarStack>
					<AxisLeft hideTicks scale={yScale} />
					<AxisBottom top={yMax} scale={xScale} />
				</Group>
			</svg>
			<div
				style={{
					position: 'absolute',
					top: margin.top / 2 - 10,
					width: '100%',
					display: 'flex',
					justifyContent: 'center',
					fontSize: '14px',
				}}
			>
				<LegendOrdinal
					className={style.Legend}
					scale={colorScale}
					direction="row"
					labelMargin="0 15px 0 0"
				/>
			</div>
			{tooltipOpen && tooltipData ? (
				<TooltipInPortal
					top={tooltipTop}
					left={tooltipLeft}
					style={defaultStyles}
				>
					<div style={{color: colorScale(tooltipData.key)}}>
						<strong>{tooltipData.key}</strong>
					</div>
					<div>{tooltipData.bar.data[tooltipData.key]}</div>
					<div>
						<small>{xGet(tooltipData.bar.data)}</small>
					</div>
				</TooltipInPortal>
			) : null}
		</div>
	);
}

/**
 * Stacked bar chart that sizes its width to the parent element using visx
 * `ParentSize`. Every other prop is forwarded to the chart as given.
 */
export function StackedBarChartResponsive<
	Key extends DataKey,
	Data extends Datum<Key>,
>(props: Omit<Props<Key, Data>, 'width'>) {
	// Render callback invoked by ParentSize with the measured container width.
	const renderChart = ({width: parentWidth}: {width: number}) => (
		<StackedBarChart width={parentWidth} {...props} />
	);

	return <ParentSize className={style.Responsive}>{renderChart}</ParentSize>;
}
