import { acd, cbct, oss } from "@promaton/api-client";
import {
  CameraUtils,
  FileType,
  getNameFromID,
  isImageFile,
  MeshName,
  Rotate,
  StockTool,
  useObjects,
  useTools,
  useViewerContext,
} from "@promaton/scan-viewer";
import { Box3, Euler, Matrix4, Mesh, Vector3 } from "three";
import { create } from "zustand";

import { getCbctCommonTaskDataUpload } from "../components/AiAssistantCbctCommon";
import { createAiTaskMetadataPayload } from "../helpers/ai/aiTaskMetadata";
import {
  COPILOT_SUCCESS,
  CopilotScanType,
  CopilotSceneContentDescription,
  CopilotStatusMessage,
} from "../helpers/copilotPrompts";
import { getAiAssistantHeaders } from "../hooks/useAiAssistantHeaders";
import {
  AI_TASK_RETENTION_TIME_MINUTES,
  AiTaskType,
  useAiAssistantState,
} from "./useAiAssistantState";
import { useAppState, ViewMode } from "./useAppState";

/**
 * One entry in the Copilot chat transcript kept by {@link useCopilot}.
 */
export type CopilotMessage = {
  /**
   * Origin of the message. `"function"` presumably marks output produced by one
   * of the {@link copilotFunctions} handlers — confirm against the chat UI.
   */
  author: "assistant" | "function" | "user";
  /** Text body of the message. */
  content: string;
};

/**
 * Zustand store for the Copilot conversation.
 *
 * - `messages` is the ordered transcript; `addMessage` appends to it immutably.
 * - `thread` is the current Copilot thread (`null` when none has been set).
 * - `restart` clears both the transcript and the thread.
 */
export const useCopilot = create<{
  messages: CopilotMessage[];
  addMessage: (message: CopilotMessage) => void;
  thread: string | null;
  setThread: (thread: string | null) => void;
  restart: () => void;
}>((set) => ({
  messages: [],
  addMessage: (message) =>
    set((state) => ({ messages: [...state.messages, message] })),
  thread: null,
  setThread: (thread) => set({ thread }),
  restart: () => set({ messages: [], thread: null }),
}));

/** Valid FDI tooth number: permanent (quadrants 1-4, teeth 1-8) or primary (quadrants 5-8, teeth 1-5). */
const FDI_TOOTH_PATTERN = /^(?:[1-4][1-8]|[5-8][1-5])$/;

/**
 * Maps an FDI tooth number to the jaw it belongs to.
 *
 * Quadrants 1, 2, 5 and 6 are upper-jaw; quadrants 3, 4, 7 and 8 are lower-jaw.
 *
 * @param tooth - FDI tooth number. It is stringified before matching, so a numeric
 *   string sent by the Copilot API is handled the same way as a number.
 * @returns `"upper"` or `"lower"`.
 * @throws Error if `tooth` is not a valid two-digit FDI tooth number, instead of
 *   silently defaulting to the lower jaw.
 */
function jawForFdiTooth(tooth: number): "upper" | "lower" {
  const fdi = String(tooth);
  if (!FDI_TOOTH_PATTERN.test(fdi)) {
    throw new Error(`Invalid FDI tooth number: ${fdi}`);
  }
  return /^[1256]/.test(fdi) ? "upper" : "lower";
}

/**
 * Downloads the contents behind `url` as a Blob.
 *
 * @throws Error when the server answers with a non-2xx status, so an error page is
 *   never uploaded to an AI task as if it were the scan file.
 */
async function fetchObjectBlob(
  url: Parameters<typeof fetch>[0]
): Promise<Blob> {
  const response = await fetch(url);
  if (!response.ok) {
    throw new Error(`Failed to download object file (HTTP ${response.status})`);
  }
  return response.blob();
}

/**
 * Implementation of functions the Copilot API may call, keyed by function name.
 *
 * Arguments arrive from the Copilot API at runtime, so handlers validate values
 * they forward to the viewer or to AI services rather than trusting the declared
 * parameter types.
 */
export const copilotFunctions: Record<string, (input: any) => object> = {
  /**
   * Describes the scene: every object (id + group), whether the scan is CBCT or
   * intraoral, which jaws have tooth segmentations, and an FDI map in which absent
   * teeth are listed explicitly.
   */
  getObjects: (_: any) => {
    const objects = useObjects.getState().objects;
    let upperJawSegmentations = false;
    let lowerJawSegmentations = false;
    const fdiMap: Record<number, string> = {};

    const objectsList = Object.entries(objects).map(([key, value]) => {
      // Names ending in a permanent-dentition FDI number are tooth segmentations:
      // 11-28 belong to the upper jaw, 31-48 to the lower jaw.
      const location = getNameFromID(key).match(/[1-4][1-8]$/);
      if (location) {
        const fdi = Number(location[0]);
        if (fdi < 30) {
          upperJawSegmentations = true;
        } else {
          lowerJawSegmentations = true;
        }
        fdiMap[fdi] = value?.group || "Tooth";
      }

      return { id: key, group: value?.group };
    });

    for (let fdiQuadrant = 1; fdiQuadrant <= 4; fdiQuadrant++) {
      if (fdiQuadrant >= 3 && !lowerJawSegmentations) continue;
      if (fdiQuadrant <= 2 && !upperJawSegmentations) continue;
      for (let tooth = 1; tooth <= 8; tooth++) {
        if (!fdiMap[fdiQuadrant * 10 + tooth]) {
          // Explicitly mention which teeth are missing from the scan, as GPT is bad at inferring this.
          fdiMap[fdiQuadrant * 10 + tooth] = tooth === 8 ? "N/A" : "Missing";
        }
      }
    }

    const scanType = Object.values(objects).some((o) =>
      isImageFile(o?.objectType)
    )
      ? CopilotScanType.CBCT
      : CopilotScanType.INTRAORAL;

    const sceneDescription =
      upperJawSegmentations && lowerJawSegmentations
        ? CopilotSceneContentDescription.SEGMENTED_BOTH
        : upperJawSegmentations
        ? CopilotSceneContentDescription.SEGMENTED_UPPER
        : lowerJawSegmentations
        ? CopilotSceneContentDescription.SEGMENTED_LOWER
        : CopilotSceneContentDescription.SCAN_ONLY;

    return { objects: objectsList, scanType, sceneDescription, fdiMap };
  },

  /** Applies visibility / color / opacity changes to each listed object by id. */
  modifyObjectConfigurations: (params: {
    objects: {
      id: string;
      hidden?: boolean;
      color?: string;
      opacity?: number;
    }[];
  }) => {
    const updateObject = useObjects.getState().updateObject;
    params.objects.forEach((object) => {
      const { id, ...rest } = object;
      updateObject(id, rest);
    });
    return COPILOT_SUCCESS;
  },

  /** Makes `id` the sole selection, then recenters every view. */
  selectObjectAndCenter: (params: { id: string }) => {
    const selectObject = useObjects.getState().setSelection;
    selectObject({ [params.id]: {} });
    CameraUtils.recenterAllViews();
    return COPILOT_SUCCESS;
  },

  /**
   * Switches the viewer layout: "3D" → SINGLE, "orthogonal" → SPLIT,
   * "slice" → CUSTOM; any other value falls back to SINGLE.
   */
  setViewMode: ({ mode }: { mode: "3D" | "slice" | "orthogonal" }) => {
    const setViewMode = useAppState.getState().setViewMode;
    setViewMode(
      mode === "3D"
        ? ViewMode.SINGLE
        : mode === "orthogonal"
        ? ViewMode.SPLIT
        : mode === "slice"
        ? ViewMode.CUSTOM
        : ViewMode.SINGLE
    );
    return COPILOT_SUCCESS;
  },

  /**
   * Starts an AI segmentation task for object `id` and registers it with the AI
   * assistant state. DICOM objects go to CBCT segmentation; STL objects go to OSS
   * and require `jaw` to be "upper" or "lower".
   *
   * @throws Error if the object is missing, the file type is unsupported, the jaw
   *   is not specified for an STL, or the STL download fails.
   */
  runSegmentation: async ({
    id,
    jaw,
  }: {
    id: string;
    jaw: "upper" | "lower" | "unknown";
  }) => {
    const object = useObjects.getState().objects[id];
    if (!object) {
      throw new Error("Object not found");
    }

    const addTask = useAiAssistantState.getState().addTask;

    if (object.objectType === FileType.DICOM) {
      const data = await getCbctCommonTaskDataUpload(id, object);
      const res = await cbct.createCbct(
        data,
        {
          meta_data: createAiTaskMetadataPayload({ name: id }),
          retain: AI_TASK_RETENTION_TIME_MINUTES,
        },
        getAiAssistantHeaders({ "Content-Type": "application/zip" })
      );
      addTask({ id: res.data.id, type: AiTaskType.CBCT });
    } else if (object.objectType === FileType.STL) {
      // Reject "unknown" — and any other unexpected value from the model — so the
      // assistant asks the user instead of sending an invalid jaw to the API.
      if (jaw !== "upper" && jaw !== "lower") {
        throw new Error("Jaw type must be specified. Ask user for input.");
      }
      const url = Array.isArray(object.url) ? object.url[0] : object.url;
      const file = await fetchObjectBlob(url);
      const res = await oss.createOss(
        jaw,
        file,
        {
          meta_data: createAiTaskMetadataPayload({ name: id }),
          retain: AI_TASK_RETENTION_TIME_MINUTES,
        },
        getAiAssistantHeaders()
      );
      addTask({ id: res.data.id, type: AiTaskType.OPTICAL });
    } else {
      throw new Error("Unsupported file type");
    }

    return {
      status: "Task started successfully",
    } as CopilotStatusMessage;
  },

  /**
   * Starts an ACD task for a single FDI `tooth` on STL object `id`; the jaw is
   * derived from the tooth's quadrant.
   *
   * @throws Error if the object is missing or not an STL, the tooth number is not
   *   valid FDI notation, or the download fails.
   */
  runACD: async ({ id, tooth }: { id: string; tooth: number }) => {
    const object = useObjects.getState().objects[id];
    if (!object) {
      throw new Error("Object not found");
    }
    if (object.objectType !== FileType.STL) {
      throw new Error("Unsupported file type");
    }

    const addTask = useAiAssistantState.getState().addTask;
    const jaw = jawForFdiTooth(tooth);

    const url = Array.isArray(object.url) ? object.url[0] : object.url;
    const file = await fetchObjectBlob(url);
    const res = await acd.createAcd(
      jaw,
      [tooth.toString()],
      file,
      {
        meta_data: createAiTaskMetadataPayload({ name: id }),
        retain: AI_TASK_RETENTION_TIME_MINUTES,
      },
      getAiAssistantHeaders()
    );
    addTask({ id: res.data.id, type: AiTaskType.ACD });

    return {
      status: "Task started successfully",
    } as CopilotStatusMessage;
  },

  /**
   * Activates one of the stock viewer tools.
   *
   * @throws Error if `tool` is not one of the supported names; previously an
   *   unknown name resulted in `setActiveTool(undefined)`.
   */
  activateTool: ({ tool }: { tool: "transform" | "sculpt" | "measure" }) => {
    const toolMap = {
      transform: StockTool.TRANSFORM,
      sculpt: StockTool.SCULPT,
      measure: StockTool.MEASURE,
    } as const;
    // Own-property check so prototype keys such as "toString" are rejected too.
    if (!Object.prototype.hasOwnProperty.call(toolMap, tool)) {
      throw new Error(`Unsupported tool: ${tool}`);
    }
    useTools.getState().setActiveTool(toolMap[tool]);
    return COPILOT_SUCCESS;
  },

  /**
   * Reports the axis-aligned bounding box (size, center, box volume) of each
   * object's main mesh, keyed by object id. Objects without a main mesh in the
   * scene are omitted.
   *
   * @throws Error if the viewer scene is not available.
   */
  getObjectBounds: (_: any) => {
    const scene = useViewerContext.getState().scene;
    if (!scene) throw new Error("Scene not found");

    const ids = Object.keys(useObjects.getState().objects);
    const result: Record<string, object> = {};
    ids.forEach((id) => {
      const obj = scene.getObjectByName(id)?.getObjectByName(MeshName.MAIN);
      if (obj instanceof Mesh) {
        const bounds = new Box3().setFromObject(obj, true);
        const size = bounds.getSize(new Vector3()).toArray();
        const center = bounds.getCenter(new Vector3()).toArray();
        // Volume of the bounding box, not of the mesh itself.
        const volume = size.reduce((a, b) => a * b, 1);
        result[id] = { size, center, volume, unit: "mm" };
      }
    });
    return result;
  },

  /**
   * Applies a relative transform to each object in `objectIds`.
   *
   * The delta matrix is built as translation · rotation · uniform scale and is
   * post-multiplied onto each object's existing transform, i.e. applied in the
   * object's local frame. Objects without a transform receive the delta as-is.
   *
   * @param translation - [x, y, z]; its z component is negated before being
   *   rotated by `Rotate.Z_UP_TO_Y_UP`. Missing components default to 0. The
   *   caller's array is not modified.
   * @param rotation - Euler angles [x, y, z] in degrees (three.js default XYZ
   *   order). Missing components default to 0; extra entries are ignored.
   *   NOTE(review): unlike `translation`, rotation is not converted from Z-up to
   *   Y-up — confirm this asymmetry is intended.
   * @param scale - Uniform scale factor; 0 or undefined leaves scale unchanged.
   */
  transformObject: ({
    objectIds,
    translation,
    rotation,
    scale,
  }: {
    objectIds: string[];
    translation?: number[];
    rotation?: number[];
    scale?: number;
  }) => {
    const toRadians = (degrees: number) => (degrees * Math.PI) / 180;
    const matrix = new Matrix4().identity();

    if (rotation) {
      const [rx = 0, ry = 0, rz = 0] = rotation;
      matrix.makeRotationFromEuler(
        new Euler(toRadians(rx), toRadians(ry), toRadians(rz))
      );
    }
    if (translation) {
      // Destructure instead of writing back into `translation`: mutating the
      // caller's array made a reused/retried argument flip its z sign again.
      const [tx = 0, ty = 0, tz = 0] = translation;
      matrix.setPosition(
        new Vector3(tx, ty, -tz).applyEuler(new Euler(...Rotate.Z_UP_TO_Y_UP))
      );
    }
    if (scale) {
      matrix.multiply(new Matrix4().makeScale(scale, scale, scale));
    }

    const objectState = useObjects.getState();
    objectIds.forEach((id) => {
      const transform = objectState.objects[id]?.transform;
      const newTransform = transform
        ? transform.clone().multiply(matrix)
        : matrix.clone();
      objectState.updateObject(id, { transform: newTransform });
    });

    return COPILOT_SUCCESS;
  },
};
