import React, { useEffect, useRef, useState, useCallback } from "react";
import {
  Stage,
  Image as KonvaImage,
  Layer,
  Line,
  Circle,
  Shape,
} from "react-konva";
import { useImage } from "react-konva-utils";
import {
  Flex,
  Text,
  Heading,
  SliderField,
  Button,
} from "@aws-amplify/ui-react";
import { FaPen, FaEraser, FaDrawPolygon, FaUndo, FaRedo } from "react-icons/fa";
import { MdOutlineAutoFixHigh, MdRestartAlt } from "react-icons/md";

const ImageDrawingCanvas = ({
  imageUrl,
  stageRef,
  linesLayerRef,
  setHasDrawnLines,
  segmentations,
  setSegmentations,
  selectedSegmentations,
  setSelectedSegmentations,
  handleGetSegmentResult,
  loadingSegmentation,
}) => {
  const [image] = useImage(imageUrl);
  const imageWidth = image ? image.naturalWidth : 0;
  const imageHeight = image ? image.naturalHeight : 0;
  const maxStageHeight = window.innerHeight * 0.6;
  const stageHeight = maxStageHeight;
  const stageWidth = (imageWidth / imageHeight) * stageHeight;
  const [strokeWidth, setStrokeWidth] = useState(20);
  const [pointerPosition, setPointerPosition] = useState({ x: -1, y: -1 });
  const [tool, setTool] = useState("select");
  const [lines, setLines] = useState([]);
  const [polygons, setPolygons] = useState([]);
  const [currentPolygon, setCurrentPolygon] = useState([]);
  const [polygonMasks, setPolygonMasks] = useState([]);
  const [history, setHistory] = useState([]);
  const [historyStep, setHistoryStep] = useState(0);
  const isDrawing = useRef(false);
  const lastPointerPosition = useRef(null);

  const getRelativePointerPosition = (stage) => {
    const transform = stage.getAbsoluteTransform().copy();
    transform.invert();
    const pos = stage.getPointerPosition();
    return transform.point(pos);
  };

  const toolNames = {
    select: "AI Mask Selection",
    polygon: "Connect dots to draw your mask",
    pen: "Brush tool",
    eraser: "Eraser",
  };

  useEffect(() => {
    initialize();
  }, [imageUrl, image]);

  const initialize = async () => {
    setLines([]);
    setPolygons([]);
    setCurrentPolygon([]);
    setPolygonMasks([]);
    setSelectedSegmentations([]);
    setSegmentations([]);
    setPointerPosition({ x: -1, y: -1 });
    setHistory([]);
    setHistoryStep(0);
    await handleGetSegmentResult();
    setHasDrawnLines(false);
  };

  const isNearStartingPoint = (x, y) => {
    if (currentPolygon.length >= 4) {
      const [startX, startY] = currentPolygon;
      const distance = Math.sqrt(
        Math.pow(x - startX, 2) + Math.pow(y - startY, 2)
      );
      return distance < 20; // Adjust this threshold as needed
    }
    return false;
  };

  const handleMouseDown = (e) => {
    setHasDrawnLines(true);
    const stage = e.target.getStage();
    const pos = getRelativePointerPosition(stage);

    if (tool === "polygon") {
      if (isNearStartingPoint(pos.x, pos.y) && currentPolygon.length >= 6) {
        const closedPolygon = [
          ...currentPolygon,
          currentPolygon[0],
          currentPolygon[1],
        ];
        setPolygons([...polygons, closedPolygon]);
        convertPolygonToMask(closedPolygon);
        setCurrentPolygon([]);
      } else {
        const newPolygon = [...currentPolygon, pos.x, pos.y];
        setCurrentPolygon(newPolygon);
      }
    } else {
      isDrawing.current = true;
      lastPointerPosition.current = pos;
      setLines([...lines, { tool, points: [pos.x, pos.y], strokeWidth }]);
    }

    saveToHistory();
  };

  const handleMouseMove = (e) => {
    const stage = e.target.getStage();
    const point = getRelativePointerPosition(stage);

    if (tool !== "polygon" && isDrawing.current) {
      const newPoints = [];
      const lastPos = lastPointerPosition.current;
      const currentPos = point;

      // Calculate control points for Bézier curve
      const midPoint = {
        x: (lastPos.x + currentPos.x) / 2,
        y: (lastPos.y + currentPos.y) / 2,
      };

      // Add points to create a smooth curve
      newPoints.push(lastPos.x, lastPos.y);
      newPoints.push(midPoint.x, midPoint.y);
      newPoints.push(currentPos.x, currentPos.y);

      setLines((prevLines) => {
        const updatedLines = [...prevLines];
        const lastLine = updatedLines[updatedLines.length - 1];

        if (lastLine && lastLine.points) {
          lastLine.points = lastLine.points.concat(newPoints);
          return updatedLines;
        }

        // If there's no last line or it doesn't have points, create a new one
        return [...prevLines, { tool, points: newPoints, strokeWidth }];
      });

      lastPointerPosition.current = currentPos;
    }
    setPointerPosition(point);
  };

  const handleMouseUp = () => {
    if (tool !== "polygon") {
      isDrawing.current = false;
    }
  };

  const handleSegmentationClick = useCallback(
    (label) => {
      setSelectedSegmentations((prev) => {
        const newSelection = prev.includes(label)
          ? prev.filter((i) => i !== label)
          : [...prev, label];
        saveToHistory(segmentations, newSelection);
        return newSelection;
      });
    },
    [segmentations]
  );

  const handleSegmentClick = (e) => {
    if (loadingSegmentation) return;
    if (tool !== "select" || segmentations?.length === 0) return;
    setHasDrawnLines(true);

    const stage = e.target.getStage();
    const pointerPosition = stage.getPointerPosition();
    const x = Math.floor(
      (pointerPosition.x / stageWidth) * segmentations[0].length
    );
    const y = Math.floor(
      (pointerPosition.y / stageHeight) * segmentations.length
    );

    if (y < segmentations.length && x < segmentations[y].length) {
      const clickedLabel = segmentations[y][x];
      handleSegmentationClick(clickedLabel);
    }
  };

  const convertPolygonToMask = (polygon) => {
    const newMask = {
      points: polygon,
      fill: "rgba(0, 0, 0, 0.5)",
    };
    setPolygonMasks([...polygonMasks, newMask]);
  };

  const saveToHistory = (newSegmentations, newSelection) => {
    const newState = {
      segmentations: newSegmentations,
      selectedSegmentations: newSelection || selectedSegmentations,
      lines: [...lines],
      polygons: [...polygons],
      currentPolygon: [...currentPolygon],
      masks: [...polygonMasks],
    };
    const newHistory = history.slice(0, historyStep + 1);
    newHistory.push(newState);
    setHistory(newHistory);
    setHistoryStep(newHistory.length - 1);
  };

  const undo = () => {
    if (historyStep > 0) {
      setHistoryStep(historyStep - 1);
      const prevState = history[historyStep - 1];
      setLines(prevState.lines);
      setPolygons(prevState.polygons);
      setCurrentPolygon(prevState.currentPolygon);
      setPolygonMasks(prevState.masks);
      setSegmentations(prevState.segmentations);
      setSelectedSegmentations(prevState.selectedSegmentations);
    }
    if (historyStep === 0) {
      initialize();
    }
  };

  const redo = () => {
    if (historyStep < history.length - 1) {
      setHistoryStep(historyStep + 1);
      const nextState = history[historyStep + 1];
      setLines(nextState.lines);
      setPolygons(nextState.polygons);
      setCurrentPolygon(nextState.currentPolygon);
      setPolygonMasks(nextState.masks);
      setSegmentations(nextState.segmentations);
      setSelectedSegmentations(nextState.selectedSegmentations);
    }
  };

  const handleReset = () => {
    initialize();
  };

  return (
    <Flex direction="column" alignItems="center">
      <Heading>{toolNames[tool]}</Heading>
      <Flex direction="row" alignItems="center">
        <MdOutlineAutoFixHigh
          size={25}
          onClick={() => setTool("select")}
          style={{
            cursor: "pointer",
            color: tool === "select" ? "green" : "black",
          }}
        />
        <FaDrawPolygon
          size={25}
          onClick={() => setTool("polygon")}
          style={{
            cursor: "pointer",
            color: tool === "polygon" ? "green" : "black",
          }}
        />
        <FaPen
          size={25}
          onClick={() => setTool("pen")}
          style={{
            cursor: "pointer",
            color: tool === "pen" ? "green" : "black",
          }}
        />
        <FaEraser
          size={25}
          onClick={() => setTool("eraser")}
          style={{
            cursor: "pointer",
            color: tool === "eraser" ? "green" : "black",
          }}
        />
        {(tool === "pen" || tool === "eraser") && (
          <>
            <Text>Pen Width</Text>
            <SliderField
              min={10}
              max={60}
              step={1}
              isValueHidden={true}
              value={strokeWidth}
              onChange={setStrokeWidth}
              color="green"
              filledTrackColor="green"
            />
          </>
        )}
        <Button onClick={undo} disabled={historyStep === 0}>
          <FaUndo />
        </Button>
        <Button
          onClick={redo}
          disabled={historyStep === history.length - 1 || historyStep === 0}
        >
          <FaRedo />
        </Button>
        <Button onClick={handleReset} title="Reset Canvas">
          <MdRestartAlt size={20} />
        </Button>
      </Flex>
      <Stage
        width={stageWidth}
        height={stageHeight}
        style={{ cursor: "crosshair" }}
        onMouseDown={tool === "select" ? handleSegmentClick : handleMouseDown}
        onMouseMove={handleMouseMove}
        onMouseUp={handleMouseUp}
        onTouchStart={tool === "select" ? handleSegmentClick : handleMouseDown}
        onTouchMove={handleMouseMove}
        onTouchEnd={handleMouseUp}
        ref={stageRef}
      >
        <Layer>
          <KonvaImage
            key={imageUrl}
            image={image}
            width={stageWidth}
            height={stageHeight}
          />
        </Layer>
        <Layer>
          {segmentations?.length > 0 && (
            <KonvaImage
              image={(() => {
                const canvas = document.createElement("canvas");
                canvas.width = segmentations[0]?.length || 1;
                canvas.height = segmentations.length || 1;
                const ctx = canvas.getContext("2d");
                const imageData = ctx.createImageData(
                  canvas.width,
                  canvas.height
                );
                for (let y = 0; y < canvas.height; y++) {
                  for (let x = 0; x < canvas.width; x++) {
                    const i = (y * canvas.width + x) * 4;
                    const label = segmentations[y]?.[x];
                    if (selectedSegmentations.includes(label)) {
                      imageData.data[i] = 0;
                      imageData.data[i + 1] = 255;
                      imageData.data[i + 2] = 0;
                      imageData.data[i + 3] = 77;
                    }
                  }
                }
                ctx.putImageData(imageData, 0, 0);
                return canvas;
              })()}
              width={stageWidth}
              height={stageHeight}
            />
          )}
        </Layer>
        <Layer ref={linesLayerRef}>
          {lines.map((line, i) => (
            <Line
              key={i}
              points={line.points}
              stroke={
                line.tool === "eraser"
                  ? "rgba(0, 0, 0, 1.0)"
                  : "rgba(223, 75, 38, 0.5)"
              }
              strokeWidth={line.strokeWidth}
              tension={0.5}
              lineCap="round"
              globalCompositeOperation={
                line.tool === "eraser" ? "destination-out" : "source-over"
              }
            />
          ))}
          {polygons.map((polygon, i) => (
            <Shape
              key={i}
              sceneFunc={(context, shape) => {
                context.beginPath();
                context.moveTo(polygon[0], polygon[1]);
                for (let i = 2; i < polygon.length; i += 2) {
                  context.lineTo(polygon[i], polygon[i + 1]);
                }
                context.closePath();
                context.fillStrokeShape(shape);
              }}
              fill="rgba(223, 75, 38, 0.6)"
              stroke="rgba(223, 75, 38, 0.6)"
              strokeWidth={2}
            />
          ))}
          {currentPolygon.length > 0 && (
            <Line
              points={currentPolygon}
              stroke="rgba(223, 75, 38, 0.6)"
              strokeWidth={2}
              tension={0}
            />
          )}
          {currentPolygon.map(
            (point, i) =>
              i % 2 === 0 && (
                <Circle
                  key={i}
                  x={point}
                  y={currentPolygon[i + 1]}
                  radius={4}
                  fill="rgba(223, 75, 38, 0.6)"
                />
              )
          )}
        </Layer>
      </Stage>
    </Flex>
  );
};

/**
 * Rasterises a layer and converts it to an opaque black/white PNG blob.
 *
 * The layer is rendered with `linesLayer.toCanvas()`. Every pixel with a
 * non-zero red, green or blue channel becomes white; pixels whose colour
 * channels are all zero keep those zeros. Alpha is forced to 255 for every
 * pixel, so the result contains only fully opaque pixels.
 *
 * @param {object} linesLayer - Object exposing `toCanvas()`, e.g. a Konva layer.
 * @param {object} [options]
 * @param {boolean} [options.download=false] - When true, also triggers a
 *   browser download of the PNG through a temporary anchor element.
 * @param {string} [options.fileName="image.png"] - File name used for the download.
 * @returns {Promise<Blob>} Resolves with the PNG blob. Rejects when the layer
 *   cannot be rendered or read, when `toBlob` yields no blob, or when the
 *   optional download step throws.
 */
export const convertCanvasToBinary = (linesLayer, options = {}) =>
  new Promise((resolve, reject) => {
    try {
      const { download = false, fileName = "image.png" } = options ?? {};

      const canvas = linesLayer.toCanvas();
      const ctx = canvas.getContext("2d");
      const imageData = ctx.getImageData(0, 0, canvas.width, canvas.height);
      const pixels = imageData.data;

      for (let i = 0; i < pixels.length; i += 4) {
        const hasColor =
          pixels[i] > 0 || pixels[i + 1] > 0 || pixels[i + 2] > 0;
        if (hasColor) {
          pixels[i] = 255; // Red
          pixels[i + 1] = 255; // Green
          pixels[i + 2] = 255; // Blue
        }
        pixels[i + 3] = 255; // Alpha (always fully opaque)
      }

      ctx.putImageData(imageData, 0, 0);

      canvas.toBlob((blob) => {
        // This callback runs outside the surrounding try block, so it needs
        // its own guard; otherwise a throw here would leave the promise
        // pending forever.
        try {
          if (!blob) {
            reject(new Error("Canvas to Blob conversion failed"));
            return;
          }

          if (download) {
            const url = URL.createObjectURL(blob);
            try {
              const anchor = document.createElement("a");
              anchor.href = url;
              anchor.download = fileName;
              document.body.appendChild(anchor);
              anchor.click();
              document.body.removeChild(anchor);
            } finally {
              // Release the object URL even if the DOM calls above throw.
              URL.revokeObjectURL(url);
            }
          }

          resolve(blob);
        } catch (error) {
          reject(error);
        }
      }, "image/png");
    } catch (error) {
      reject(error);
    }
  });

// Default export: the interactive mask-drawing canvas component.
export default ImageDrawingCanvas;
