import { alignment, cbct, oss } from "@promaton/api-client";
import { AlignmentTask } from "@promaton/api-client";
import { applyCBCTConvention } from "@promaton/file-processing/src/conventionInternalCBCT";
import { applyUniversalConvention } from "@promaton/file-processing/src/conventionUniversal";
import {
  CameraUtils,
  FileType,
  useObjects,
  ViewerObject,
  ViewerObjectMap,
} from "@promaton/scan-viewer";
import { Matrix4 } from "three";

import { getAiAssistantHeaders } from "../../hooks/useAiAssistantHeaders";
import {
  AiTaskType,
  useAiAssistantState,
} from "../../stores/useAiAssistantState";
import { parseAiTaskMetadata } from "./aiTaskMetadata";
import { loadCbctResults } from "./loadCbctResults";
import { loadOssTaskResults } from "./loadOssTaskResults";

/**
 * Writes a nested-array matrix into a three.js `Matrix4`.
 *
 * The rows of `input` are concatenated in order, loaded into `out` with
 * `fromArray`, and `out` is then transposed in place.
 *
 * NOTE(review): the transpose presumably converts a row-major payload into
 * the column-major layout `Matrix4.fromArray` reads — confirm against the
 * alignment API contract.
 *
 * @param input - Matrix values as an array of rows.
 * @param out - Matrix that receives the values; a new `Matrix4` when omitted.
 * @returns The same `out` instance, after it has been overwritten.
 */
export const alignmentMatrixToTransform = (
  input: number[][],
  out = new Matrix4()
) => {
  const flattenedRows = input.flat();
  out.fromArray(flattenedRows);
  out.transpose();
  return out;
};

/**
 * Builds the viewer objects for an alignment task: the results of every
 * related CBCT and optical task, plus the fused structures when the task
 * result contains any.
 *
 * Related tasks are processed one after another, in the order they appear in
 * the task metadata. Objects from later tasks overwrite earlier ones that
 * share a key.
 *
 * @param task - Alignment task whose `result` and `meta_data` are read.
 * @param load - When true, the collected objects are handed to the
 *   `useObjects` store and all views are re-centered shortly afterwards.
 * @param displayResultProgress - Forwarded to the CBCT/OSS loaders and used
 *   to decide whether fused-STL download progress is written to
 *   `useAiAssistantState`.
 * @returns The collected object map, or `undefined` when the task has no
 *   result or its metadata lists no related tasks.
 */
export const loadAlignmentTaskResults = async (
  task: AlignmentTask,
  load = true,
  displayResultProgress = true
) => {
  const result = task.result;
  if (!result) return;

  const parsedMetadata = parseAiTaskMetadata(task.meta_data);
  const requestOptions = getAiAssistantHeaders();

  const relatedTasks = parsedMetadata.related;
  if (!relatedTasks) return;

  const collected: ViewerObjectMap = {};

  // Fusion is only considered present when `structures_fused` has at least
  // one own key.
  const fused = result.structures_fused;
  const isFusion = fused && Object.keys(fused).length > 0;

  for (const relatedTask of relatedTasks) {
    switch (relatedTask.type) {
      case AiTaskType.CBCT: {
        const cbctResponse = await cbct.getCbct(relatedTask.id, requestOptions);
        if (!cbctResponse.data) break;

        const cbctObjects = await loadCbctResults(
          cbctResponse.data,
          false,
          displayResultProgress
        );
        if (!cbctObjects) break;

        for (const viewerObject of Object.values(cbctObjects)) {
          // With fused structures available, the source "Teeth" group is
          // hidden (the check runs before the group is renamed below).
          if (isFusion && viewerObject.group === "Teeth") {
            viewerObject.hidden = true;
          }
          viewerObject.group = `CBCT ${viewerObject.group || ""}`;
        }
        Object.assign(collected, cbctObjects);
        break;
      }

      case AiTaskType.OPTICAL: {
        const scanTransform = new Matrix4();
        const matrices = result.transformation_matrices;

        // NOTE(review): when no entry matches this task id, the first matrix
        // in the list is used instead — confirm this fallback is intended.
        const foundAt = matrices.findIndex(
          (entry) => entry?.task_id === relatedTask.id
        );
        const matrixIndex = foundAt < 0 ? 0 : foundAt;

        const matrixRows = matrices[matrixIndex]?.transformation_matrix;
        if (matrixRows) {
          alignmentMatrixToTransform(matrixRows, scanTransform);
        }

        const ossResponse = await oss.getOss(relatedTask.id, requestOptions);
        if (!ossResponse.data) break;

        const ossObjects = await loadOssTaskResults(
          ossResponse.data,
          false,
          displayResultProgress
        );
        if (!ossObjects) break;

        for (const viewerObject of Object.values(ossObjects)) {
          // Each object receives its own copy of the transform.
          viewerObject.transform = scanTransform.clone();
          if (isFusion && viewerObject.group === "Teeth") {
            viewerObject.hidden = true;
          }
          viewerObject.group = `IOS ${viewerObject.group || ""}`;
        }
        Object.assign(collected, ossObjects);
        break;
      }
    }
  }

  if (fused && isFusion) {
    // NOTE(review): the non-zero starting value presumably marks loading as
    // "started" for the UI — confirm.
    let progress = 0.0001;
    // NOTE(review): assumes `structures_fused.stl` is always present when
    // `structures_fused` has keys — confirm against the API schema.
    const structureNames = Object.keys(fused.stl);
    const progressStep = 1 / structureNames.length;

    // Downloads one fused STL and wraps it as a `[key, ViewerObject]` entry.
    // Progress advances whether the download succeeds or fails.
    const fetchFusedEntry = (structureName: string) =>
      alignment
        .getAlignmentStlStructure(
          task.id,
          "fused",
          structureName,
          requestOptions
        )
        .then(
          (response) =>
            [
              `FUSED_TOOTH_${structureName.toUpperCase()}`,
              {
                url: URL.createObjectURL(response.data),
                objectType: FileType.STL,
                clipToPlanes: true,
              },
            ] as [string, ViewerObject]
        )
        .finally(() => {
          progress += progressStep;
          if (displayResultProgress) {
            useAiAssistantState.setState({ resultLoadingProgress: progress });
          }
        });

    // All fused structures are requested concurrently.
    const fusedEntries = await Promise.all(
      structureNames.map((structureName) => fetchFusedEntry(structureName))
    );

    const fusedObjects = Object.fromEntries(fusedEntries) as ViewerObjectMap;
    applyUniversalConvention(fusedObjects);
    applyCBCTConvention(fusedObjects);
    for (const viewerObject of Object.values(fusedObjects)) {
      viewerObject.group = `Fused Teeth`;
    }
    Object.assign(collected, fusedObjects);
  }

  if (load) {
    useObjects.getState().setObjects(collected, false);
    // NOTE(review): the 200 ms delay presumably lets the viewer pick up the
    // new objects before re-centering — confirm.
    setTimeout(() => {
      CameraUtils.recenterAllViews();
    }, 200);
  }

  return collected;
};
