import React, { useState, useRef, useEffect } from "react";
import {
  Flex,
  Image,
  ScrollView,
  Text,
  Button,
  Loader,
  Alert,
  SwitchField,
} from "@aws-amplify/ui-react";
import ImageDrawingCanvas from "./imageDrawingCanvas";
import CompareSlider from "./compareSlider";
import ImageSearchModal from "./imageSearchModal";
import {
  resizeMaskImage,
  encodeBlobToBase64,
  base64ToBlob,
  saveBlobToFile,
} from "../imageUtils";
import { post } from "aws-amplify/api";
import { genImageCaptionAPI } from "../../../components/api";
import Modal from "react-modal";

// Negative prompt sent as `negative_prompt` with every "/resia-gen" inpainting
// request; it lists content and qualities the generated image should avoid.
const NEGATIVE_PROMPT =
  "unrelated objects, outdoor scenes, people, animals, vehicles, text or watermarks, blurry, distorted, low-resolution, overexposed, underexposed, noisy, cartoonish, unrealistic, clutter, messy, damaged furniture, poor lighting, uncoordinated colors, incomplete render, vintage, retro, outdated, gothic, medieval, fantasy";
// Runs once, when this module is first imported.
// NOTE(review): presumably lets react-modal mark the app root (#root) as hidden
// from assistive technology while a modal is open — confirm against the
// react-modal docs, and that #root exists before this module is loaded.
Modal.setAppElement("#root");

const CustomizePage = ({ selectedItems, setSelectedItems, showExInfo }) => {
  const [prompt, setPrompt] = useState("");
  const [selectedImageIndex, setSelectedImageIndex] = useState(0);
  const [showSearchModal, setShowSearchModal] = useState(false);
  const linesLayerRef = useRef(null);
  const stageRef = useRef(null);
  const [product, setProduct] = useState(null);
  const [loading, setLoading] = useState(false);
  const [imageViewURL, setImageViewURL] = useState(null);
  const [errorMessage, setErrorMessage] = useState(null);
  const [hasDrawnLines, setHasDrawnLines] = useState(false);
  const [isCompareDisabled, setIsCompareDisabled] = useState(true);
  const [viewComponent, setViewComponent] = useState("canvas"); // canvas or slider
  const [segmentations, setSegmentations] = useState([]);
  const [selectedSegmentations, setSelectedSegmentations] = useState([]);
  const [loadingSegmentation, setLoadingSegmentation] = useState(false);
  const seed = Math.floor(Math.random() * 100) + 1;

  useEffect(() => {
    const selectedItem = selectedItems[selectedImageIndex];
    if (imageViewURL === null) {
      setImageViewURL(
        selectedItem.inpainted[selectedItem.inpainted.length - 1]
      );
    }
    if (selectedItem.inpainted.length > 1) {
      setIsCompareDisabled(false);
    } else {
      setIsCompareDisabled(true);
    }
  }, [selectedItems]);

  useEffect(() => {
    if (hasDrawnLines) setErrorMessage(null);
  }, [hasDrawnLines]);

  const handleGetSegmentResult = async () => {
    setLoadingSegmentation(true);
    const selectedItem = selectedItems[selectedImageIndex];
    const input_image = await encodeBlobToBase64(
      selectedItem.inpainted[selectedItem.inpainted.length - 1]
    );
    const payload = {
      body: {
        image: input_image,
      },
    };
    try {
      const restOperation = post({
        apiName: "openaiAPI",
        path: "/segment",
        options: {
          body: payload,
          headers: {
            "Content-Type": "application/json",
          },
        },
      });
      const { body } = await restOperation.response;
      const response = await body.json();
      const result = JSON.parse(response.body);
      setSegmentations(result["seg_result"]);
    } catch (error) {
      console.error("Error getting segmentation result:", error);
      setErrorMessage({
        heading: "Error in Segmentation",
        message: "Failed to get segmentation result. Please try again.",
      });
    } finally {
      setLoadingSegmentation(false);
    }
  };

  const handleViewComponentChange = () => {
    if (viewComponent === "canvas") {
      setViewComponent("slider");
    } else {
      setViewComponent("canvas");
    }
  };

  const handleImageClick = (index) => {
    setSelectedImageIndex(index);
    const selectedItem = selectedItems[index];
    setImageViewURL(selectedItem.inpainted[selectedItem.inpainted.length - 1]);
  };

  const handleSearchButton = () => {
    if (!hasDrawnLines) {
      setErrorMessage({
        heading: "No Mask Drawn",
        message: "Please draw at least one mask before searching.",
      });
      return;
    }
    setShowSearchModal(true);
  };

  const handleCaption = async () => {
    try {
      const imageURL = product.src;
      if (!imageURL) return prompt;

      try {
        return await genImageCaptionAPI(imageURL);
      } catch (error) {
        setErrorMessage({
          heading: "Error Processing Image",
          message: "Please try again.",
        });
        return null;
      }
    } catch (error) {
      return prompt;
    }
  };

  const extractMask = async () => {
    if (!stageRef.current || segmentations.length === 0) return null;

    const originalWidth = segmentations[0].length;
    const originalHeight = segmentations.length;

    const canvas = document.createElement("canvas");
    const context = canvas.getContext("2d");
    canvas.width = originalWidth;
    canvas.height = originalHeight;

    // Fill the canvas with black (background)
    context.fillStyle = "black";
    context.fillRect(0, 0, canvas.width, canvas.height);

    // Render selected segmentations
    const imageData = context.createImageData(canvas.width, canvas.height);
    for (let y = 0; y < segmentations.length; y++) {
      for (let x = 0; x < segmentations[y].length; x++) {
        const label = segmentations[y][x];
        if (selectedSegmentations.includes(label)) {
          const i = (y * canvas.width + x) * 4;
          imageData.data[i] = 255; // R
          imageData.data[i + 1] = 255; // G
          imageData.data[i + 2] = 255; // B
          imageData.data[i + 3] = 255; // A
        }
      }
    }
    context.putImageData(imageData, 0, 0);

    // Render pen strokes
    if (linesLayerRef.current) {
      const linesLayer = linesLayerRef.current;
      const linesCanvas = linesLayer.toCanvas();
      const linesContext = linesCanvas.getContext("2d");
      const linesImageData = linesContext.getImageData(
        0,
        0,
        linesCanvas.width,
        linesCanvas.height
      );
      const linesData = linesImageData.data;

      // Calculate scaling factors
      const scaleX = linesCanvas.width / canvas.width;
      const scaleY = linesCanvas.height / canvas.height;

      // Combine segmentation mask with pen strokes
      for (let i = 0; i < imageData.data.length; i += 4) {
        const canvasIndex = i / 4;
        const canvasX = canvasIndex % canvas.width;
        const canvasY = Math.floor(canvasIndex / canvas.width);

        // Scale coordinates to match linesCanvas
        const linesX = Math.floor(canvasX * scaleX);
        const linesY = Math.floor(canvasY * scaleY);
        const linesIndex = (linesY * linesCanvas.width + linesX) * 4;

        // If pixel is not black (0,0,0) in lines, make it white in the final mask
        if (
          linesData[linesIndex] > 0 ||
          linesData[linesIndex + 1] > 0 ||
          linesData[linesIndex + 2] > 0
        ) {
          imageData.data[i] = 255; // Red
          imageData.data[i + 1] = 255; // Green
          imageData.data[i + 2] = 255; // Blue
          imageData.data[i + 3] = 255; // Alpha
        }
        // If pixel is black in lines but white in segmentation, keep it white
        else if (imageData.data[i] === 255) {
          // Do nothing, keep the white pixel from segmentation
        }
        // If pixel is black in both, make it black in the final mask
        else {
          imageData.data[i] = 0; // Red
          imageData.data[i + 1] = 0; // Green
          imageData.data[i + 2] = 0; // Blue
          imageData.data[i + 3] = 255; // Alpha
        }
      }

      // Put the combined image data back to the context
      context.putImageData(imageData, 0, 0);
    }

    // // Create a download link for the mask image
    // const link = document.createElement("a");
    // link.download = "seg_mask.jpg";
    // link.href = canvas.toDataURL();
    // link.click();

    return canvas.toDataURL().split(",")[1];
  };

  const handleGenerateInpainting = async () => {
    setLoading(true);
    const imageCaption = await handleCaption();
    if (!imageCaption) {
      setLoading(false);
      return;
    }
    const selectedItem = selectedItems[selectedImageIndex];
    const encodedImage = await encodeBlobToBase64(
      selectedItem.inpainted[selectedItem.inpainted.length - 1]
    );

    const encodedMaskImage = await extractMask();

    const resizedEncodedMaskImage = await resizeMaskImage(
      encodedImage,
      encodedMaskImage
    );

    // const blobImage = base64ToBlob(encodedImage, "image/png");
    // const blobMaskImage = base64ToBlob(resizedEncodedMaskImage, "image/png");
    // saveBlobToFile(blobImage, "image.png");
    // saveBlobToFile(blobMaskImage, "mask.png");

    const payload = {
      body: {
        prompt: imageCaption + " " + prompt,
        negative_prompt: NEGATIVE_PROMPT,
        input_image: encodedImage,
        mask_image: resizedEncodedMaskImage,
      },
    };

    try {
      const restOperation = post({
        apiName: "openaiAPI",
        path: "/resia-gen",
        options: {
          body: payload,
          headers: {
            "Content-Type": "application/json",
          },
        },
      });

      const { body } = await restOperation.response;
      const response = await body.json();
      const responseImage = JSON.parse(response.body).image;
      const imageObjectURL = `data:image/png;base64,${responseImage}`;

      setSelectedItems((prevSelectedItems) => {
        const newSelectedItems = [...prevSelectedItems];
        const newItem = { ...newSelectedItems[selectedImageIndex] };
        newItem.inpainted = [...newItem.inpainted, imageObjectURL];
        newSelectedItems[selectedImageIndex] = newItem;
        return newSelectedItems;
      });

      setImageViewURL(imageObjectURL);
    } catch (error) {
      setErrorMessage({
        heading: "Error Generating Inpainted Image",
        message:
          "An error occurred while generating the image. Please try again.",
      });
    } finally {
      setErrorMessage(null);
      setLoading(false);
      setShowSearchModal(false);
    }
  };

  return (
    <Flex
      className="create-flex"
      direction="column"
      alignItems="center"
      borderRadius="20px"
      padding="20px"
      gap="20px"
    >
      {errorMessage && (
        <Alert
          variation="error"
          isDismissible={true}
          hasIcon={true}
          heading={errorMessage.heading}
          onDismiss={() => setErrorMessage(null)}
        >
          {errorMessage.message}
        </Alert>
      )}
      {showExInfo && (
        <Alert variation="info" hasIcon={true}>
          You included an example image in the previous step. <br />
          To proceed the next step, please upload your own image and unselect
          the example image by clicking the heart.
        </Alert>
      )}
      <Flex direction="row" alignItems="start" gap="20px">
        <Flex direction="column" alignContent="center" justifyContent="center">
          {selectedImageIndex !== null && (
            <Flex direction="column" borderRadius="medium" variation="elevated">
              {viewComponent === "slider" ? (
                <CompareSlider
                  originalImage={selectedItems[selectedImageIndex].inpainted[0]}
                  newImage={
                    selectedItems[selectedImageIndex].inpainted[
                      selectedItems[selectedImageIndex].inpainted.length - 1
                    ]
                  }
                />
              ) : (
                <ImageDrawingCanvas
                  imageUrl={imageViewURL}
                  stageRef={stageRef}
                  linesLayerRef={linesLayerRef}
                  setHasDrawnLines={setHasDrawnLines}
                  segmentations={segmentations}
                  setSegmentations={setSegmentations}
                  selectedSegmentations={selectedSegmentations}
                  setSelectedSegmentations={setSelectedSegmentations}
                  handleGetSegmentResult={handleGetSegmentResult}
                  loadingSegmentation={loadingSegmentation}
                />
              )}
              {loadingSegmentation && (
                <Flex direction="column" gap="0">
                  <Text className="animated-gradient-text" fontSize="1rem">
                    Processing Image for AI Mask Selection
                  </Text>
                  <Loader variation="linear" filledColor="green.60" />
                </Flex>
              )}

              <Flex direction="row" justifyContent="space-between">
                <SwitchField
                  label="Compare with original"
                  isDisabled={isCompareDisabled}
                  onChange={handleViewComponentChange}
                  trackCheckedColor="green.60"
                  marginLeft="10px"
                ></SwitchField>
                <Flex
                  direction="row"
                  gap="20px"
                  alignItems="center"
                  marginRight="10px"
                >
                  <Text fontWeight="600">Replace selected mask with</Text>
                  <Button
                    className="animated-gradient"
                    width="140px"
                    marginLeft="-10px"
                    onClick={handleSearchButton}
                    borderColor="transparent"
                    color="white"
                    borderRadius="20px"
                    disabled={loading || viewComponent === "slider"}
                  >
                    Search
                  </Button>
                </Flex>
              </Flex>
            </Flex>
          )}
          <ScrollView
            width="100%"
            style={{ overflowX: "scroll", whiteSpace: "nowrap" }}
          >
            <Flex direction="row" gap="20px" alignItems="center">
              {selectedItems.length > 0 ? (
                <>
                  {selectedItems.map((item, index) => (
                    <Image
                      key={index}
                      src={item.inpainted[item.inpainted.length - 1]}
                      alt={`Selected ${index}`}
                      onClick={() => handleImageClick(index)}
                      style={{ cursor: "pointer", height: "100px" }}
                      borderRadius="20px"
                      marginRight="10px"
                    />
                  ))}
                </>
              ) : (
                <Text>Loading...</Text>
              )}
            </Flex>
          </ScrollView>
        </Flex>

        <ImageSearchModal
          showSearchModal={showSearchModal}
          setShowSearchModal={setShowSearchModal}
          prompt={prompt}
          setPrompt={setPrompt}
          handleGenerateInpainting={handleGenerateInpainting}
          selectedCategory={selectedItems[selectedImageIndex].category}
          product={product}
          setProduct={setProduct}
          loading={loading}
          errorMessage={errorMessage}
          setErrorMessage={setErrorMessage}
        />
      </Flex>
    </Flex>
  );
};

export default CustomizePage;
