import * as tf from '@tensorflow/tfjs';
import { setWasmPaths } from '@tensorflow/tfjs-backend-wasm';
import * as log from 'src/utils/logger';

/** One NMS-selected detection row: [y1, x1, y2, x2, score, label]. */
type DetectionRow = [number, number, number, number, number, number];

/** Axis-aligned box in model-input pixels: [x1, y1, x2, y2]. */
type BoxXYXY = [number, number, number, number];

/** A detected region together with the measurements used to build user prompts. */
interface LayoutEntry {
  /** Class index; per the original comments 0 is wall, 1 is floor, 2 is ceiling. */
  label: number;
  /** Region area in pixels (box area, or number of pixels selected by the mask). */
  area: number;
  /** 'off_edge', 'on_edge_<sides>' or 'no_wall'. */
  type: string;
  /** Present only for entries measured from the bounding box. */
  box?: number[];
}

/**
 * Wraps a TensorFlow.js segmentation graph model and turns its detections into
 * photo-quality prompts (for example "step back", "include the top edge").
 *
 * Typical use: `await loadModel()` once, then `await segmentImage(source)` per
 * image; the resolved value is the list of prompts collected for that image
 * (empty when nothing needs correcting).
 */
export default class AIModel {
  prompt_text = 'Starting AI Analysis...';
  /** Number of class-score columns in the detection head. */
  numClass = 3;
  /** Height/width/channel count of the mask prototype output. */
  modelSegHeight = 160;
  modelSegWidth = 160;
  modelSegChannel = 32;
  /** Side length (pixels) of the square model input. */
  modelSize = 640;
  /** Default band width (pixels) used by getTypeByMask for edge tests. */
  maskThreshold = 4;
  distance_threshold = 5;
  length_threshold = 0.35;
  /** Padding ratios produced by the most recent preprocess() call. */
  xRatio: any;
  yRatio: any;
  /** Duration of the last predict() call, formatted as "<n> ms". */
  t_prediction: string | null | number = 0;
  /** Prompts collected for the image currently being analysed. */
  promptHistory: string[] = [];
  tfmodel: any;

  /** Switches TensorFlow.js to production mode and loads the graph model. */
  async loadGraph() {
    tf.enableProdMode();
    this.tfmodel = await tf.loadGraphModel('/model/web_model_first/model.json');
  }

  /**
   * Loads the model, selecting the WASM backend when REACT_APP_TF_WASM is '1'.
   *
   * Failures are logged rather than rethrown (best effort); `tfmodel` then
   * stays unset and segmentImage() throws 'Model Not Defined'.
   */
  async loadModel() {
    try {
      const useWasm = process.env.REACT_APP_TF_WASM === '1';
      if (useWasm) {
        setWasmPaths({
          'tfjs-backend-wasm.wasm': '/tfjs-backend-wasm.wasm',
          'tfjs-backend-wasm-simd.wasm': '/tfjs-backend-wasm-simd.wasm',
          'tfjs-backend-wasm-threaded-simd.wasm':
            '/tfjs-backend-wasm-threaded-simd.wasm',
        });
        // The WASM path uses a slightly wider edge band for mask checks.
        this.maskThreshold = 5;
        await tf.setBackend('wasm');
        await this.loadGraph();
      } else {
        await this.loadGraph();
        tf.env().set('WEBGL_DELETE_TEXTURE_THRESHOLD', 50000000);
      }
    } catch (e) {
      log.error({ event: 'Error loading model', data: { error: e } });
    }
  }

  /** Logs the current TensorFlow.js memory statistics (debug aid). */
  printAllocatedTensors() {
    console.log(tf.memory());
  }

  /**
   * Converts an image source into a normalized, batched square tensor.
   *
   * The image is padded on the bottom/right to a square, resized to
   * modelSize x modelSize, divided by 255 and given a batch dimension.
   *
   * @param source Anything accepted by `tf.browser.fromPixels`.
   * @returns `[input, xRatio, yRatio]` where the ratios are
   *   paddedSize / originalWidth and paddedSize / originalHeight.
   */
  preprocess(source: any) {
    // Always overwritten inside the (synchronous) tidy callback below.
    let xRatio = 1;
    let yRatio = 1;

    const input = tf.tidy(() => {
      const img = tf.browser.fromPixels(source);
      const [h, w] = img.shape.slice(0, 2);
      const maxSize = Math.max(w, h);

      // Pad [h, w] up to [maxSize, maxSize]; padding goes bottom/right only.
      const imgPadded: any = img.pad([
        [0, maxSize - h],
        [0, maxSize - w],
        [0, 0],
      ]);

      xRatio = maxSize / w;
      yRatio = maxSize / h;

      return tf.image
        .resizeBilinear(imgPadded, [this.modelSize, this.modelSize])
        .div(255.0)
        .expandDims(0);
    });

    return [input, xRatio, yRatio];
  }

  /**
   * Runs the model on one image and returns the resulting prompts.
   *
   * @param imageData Image source forwarded to preprocess().
   * @returns The prompt list produced by postprocessSegmentation().
   * @throws Error('Model Not Defined') when no model has been loaded.
   */
  async segmentImage(imageData: any) {
    const tfmodel = this.tfmodel;
    if (!tfmodel) {
      throw new Error('Model Not Defined');
    }

    tf.engine().startScope();
    let input: any;
    let outputTensor: any;
    try {
      this.printAllocatedTensors();
      this.promptHistory = [];

      const [tensor, xRatio, yRatio] = this.preprocess(imageData);
      input = tensor;
      this.xRatio = xRatio;
      this.yRatio = yRatio;

      const startPredict = performance.now();
      outputTensor = await tfmodel.predict(input);
      const endPredict = performance.now();
      this.t_prediction = `${(endPredict - startPredict).toFixed(2)} ms`;

      const result = await this.postprocessSegmentation(outputTensor);
      this.printAllocatedTensors();
      return result;
    } finally {
      // Runs on failure too, so a throwing predict/post-process step cannot
      // leave the engine scope open or leak the input/output tensors.
      tf.dispose([input, outputTensor]);
      tf.engine().endScope();
    }
  }

  /**
   * Decodes the model output into layout entries and derives prompts from them.
   *
   * Detections touching the frame edge (per getTypeByBox) are measured from
   * their box; the remaining ones are measured from their segmentation mask.
   *
   * @param outputTensor Model output; index 0 is read as the detection head and
   *   index 1 as the mask prototypes.
   * @returns `this.promptHistory` after getImageMaskQA() has run.
   */
  async postprocessSegmentation(outputTensor: any) {
    const layoutData: LayoutEntry[] = [];
    // Every tensor that outlives a tidy() is registered here and released in
    // the finally block, so an exception cannot leak it.
    const owned: any[] = [];
    const keep = (tensor: any) => {
      owned.push(tensor);
      return tensor;
    };

    try {
      const transRes = keep(
        tf.tidy(() => outputTensor[0].transpose([0, 2, 1]).squeeze())
      );

      // Columns 0-3 are read as center-x, center-y, width, height and
      // converted to corner form [y1, x1, y2, x2].
      const boxes: any = keep(
        tf.tidy(() => {
          const w = transRes.slice([0, 2], [-1, 1]);
          const h = transRes.slice([0, 3], [-1, 1]);
          const x1 = tf.sub(transRes.slice([0, 0], [-1, 1]), tf.div(w, 2));
          const y1 = tf.sub(transRes.slice([0, 1], [-1, 1]), tf.div(h, 2));
          return tf.concat([y1, x1, tf.add(y1, h), tf.add(x1, w)], 1).squeeze();
        })
      );

      const [scores, classes] = tf.tidy(() => {
        const rawScores = transRes
          .slice([0, 4], [-1, this.numClass])
          .squeeze();
        return [rawScores.max(1), rawScores.argMax(1)];
      });
      owned.push(scores, classes);

      // Keep at most 20 boxes, IoU threshold 0.5, score threshold 0.1.
      const nms = keep(
        await tf.image.nonMaxSuppressionAsync(boxes, scores, 20, 0.5, 0.1)
      );

      const detReady = keep(
        tf.tidy(() =>
          tf.concat(
            [
              boxes.gather(nms, 0),
              scores.gather(nms, 0).expandDims(1),
              classes.gather(nms, 0).expandDims(1),
            ],
            1
          )
        )
      );

      // One readback for all rows instead of a slice + dataSync per row
      // (the per-row slice tensors were never disposed).
      const detections: DetectionRow[] = detReady.arraySync();

      // Built lazily: only needed if at least one detection is off-edge, and
      // identical for every detection, so it is computed at most once.
      let masks: any = null;

      for (const [index, row] of detections.entries()) {
        const [rawY1, rawX1, rawY2, rawX2, score, label] = row;
        if (!(score > 0.2)) continue;

        // Negative coordinates are clamped to the frame origin.
        const y1 = rawY1 < 0 ? 0 : rawY1;
        const x1 = rawX1 < 0 ? 0 : rawX1;
        const y2 = rawY2 < 0 ? 0 : rawY2;
        const x2 = rawX2 < 0 ? 0 : rawX2;

        const box: BoxXYXY = [x1, y1, x2, y2];
        const boxType = this.getTypeByBox(box);

        if (boxType !== 'off_edge') {
          if (this.validByBox(box)) {
            layoutData.push({
              label: label, // 0: red/wall, 1: blue/floor, 2: green/ceiling
              box: box,
              area: this.boxArea(box),
              type: boxType,
            });
          }
          continue;
        }

        if (masks === null) {
          masks = keep(this.buildInstanceMasks(outputTensor, transRes, nms));
        }
        const entry = this.measureMaskDetection(masks, index, box, label);
        if (entry !== null) layoutData.push(entry);
      }
    } finally {
      tf.dispose(owned);
    }

    this.getImageMaskQA(layoutData);
    return this.promptHistory;
  }

  /**
   * Combines per-detection mask coefficients with the mask prototypes.
   *
   * @returns A tensor reshaped to [nms.shape[0], modelSegHeight, modelSegWidth];
   *   the caller owns it and must dispose it.
   */
  private buildInstanceMasks(outputTensor: any, transRes: any, nms: any): any {
    return tf.tidy(() => {
      // Prototypes moved channel-first, then flattened to [channels, h * w].
      const prototypes = outputTensor[1].transpose([0, 3, 1, 2]).squeeze();
      const coefficients = transRes
        .slice([0, 4 + this.numClass], [-1, this.modelSegChannel])
        .squeeze();
      return coefficients
        .gather(nms, 0)
        .matMul(prototypes.reshape([this.modelSegChannel, -1]))
        .reshape([nms.shape[0], this.modelSegHeight, this.modelSegWidth]);
    });
  }

  /**
   * Measures one off-edge detection from its segmentation mask.
   *
   * @param masks Tensor produced by buildInstanceMasks().
   * @param index Row of `masks` belonging to this detection.
   * @param box Clamped detection box in model-input pixels.
   * @param label Class index of the detection.
   * @returns A layout entry when the mask area exceeds 15000 pixels, else null.
   */
  private measureMaskDetection(
    masks: any,
    index: number,
    box: BoxXYXY,
    label: number
  ): LayoutEntry | null {
    const [x1, y1, x2, y2] = box;

    // Box expressed on the prototype grid: [y, x, h, w].
    const downSampleBox = [
      Math.floor((y1 * this.modelSegHeight) / this.modelSize),
      Math.floor((x1 * this.modelSegWidth) / this.modelSize),
      Math.round(((y2 - y1) * this.modelSegHeight) / this.modelSize),
      Math.round(((x2 - x1) * this.modelSegWidth) / this.modelSize),
    ];
    // Box scaled by the padding ratios: [y, x, h, w].
    const upSampleBox = [
      Math.floor(y1 * this.yRatio),
      Math.floor(x1 * this.xRatio),
      Math.round((y2 - y1) * this.yRatio),
      Math.round((x2 - x1) * this.xRatio),
    ];

    const temporaries: any[] = [];
    try {
      const proto = tf.tidy(() => {
        const sliced = masks.slice(
          [index, Math.max(downSampleBox[0], 0), Math.max(downSampleBox[1], 0)],
          [
            1,
            Math.min(downSampleBox[2], this.modelSegHeight - downSampleBox[0]),
            Math.min(downSampleBox[3], this.modelSegWidth - downSampleBox[1]),
          ]
        );
        return sliced.squeeze().expandDims(-1);
      });
      temporaries.push(proto);

      // squeeze() also drops a height/width of 1; such degenerate crops are
      // skipped (and, unlike before, still disposed by the finally block).
      if (!(proto.shape.length > 2)) return null;

      const upsampleProto = tf.image.resizeBilinear(proto, [
        upSampleBox[2],
        upSampleBox[3],
      ]);
      temporaries.push(upsampleProto);

      // NOTE(review): the trailing pad amounts become negative if the scaled
      // box extends past modelSize — confirm whether that can occur upstream.
      const mask = tf.tidy(() =>
        upsampleProto
          .pad([
            [
              upSampleBox[0],
              this.modelSize - (upSampleBox[0] + upSampleBox[2]),
            ],
            [
              upSampleBox[1],
              this.modelSize - (upSampleBox[1] + upSampleBox[3]),
            ],
            [0, 0],
          ])
          .less(0.5)
      );
      temporaries.push(mask);

      const area = this.maskArea(mask);
      if (!(area > 15000)) return null;

      return {
        label: label, // 0: red/wall, 1: blue/floor, 2: green/ceiling
        area: area,
        type: label === 0 ? this.getCorrectMask(mask) : 'no_wall',
      };
    } finally {
      tf.dispose(temporaries);
    }
  }

  /** Classifies a mask by the frame edges it touches; see getTypeByMask(). */
  getCorrectMask(mask: any) {
    return this.getTypeByMask(mask);
  }

  /** Returns the area of an [x1, y1, x2, y2] box. */
  boxArea(box: any) {
    const [x1, y1, x2, y2] = box;
    return (x2 - x1) * (y2 - y1);
  }

  /**
   * Checks that a box is large enough to be considered.
   *
   * @param box [x1, y1, x2, y2].
   * @param disThreshold Minimum width and height.
   * @param areaThresold Minimum area.
   * @returns false when either side or the area is below its minimum.
   */
  validByBox(box: any, disThreshold = 50, areaThresold = 30000) {
    const [x1, y1, x2, y2] = box;
    const width = x2 - x1;
    const height = y2 - y1;

    const sidesOk = !(width < disThreshold || height < disThreshold);
    return sidesOk && !(width * height < areaThresold);
  }

  /**
   * Finds the extent of all zero-valued cells of a [rows, cols, 1] matrix.
   *
   * @param mat Object exposing `shape` and `arraySync()`.
   * @returns Height and width of the tightest box around the zero cells, or
   *   zeros when the matrix contains none.
   */
  getBoundingBox(mat: any) {
    const maskData = mat.arraySync();
    const [rows, cols] = mat.shape;

    let minRow = rows;
    let maxRow = -1;
    let minCol = cols;
    let maxCol = -1;

    for (let r = 0; r < rows; r++) {
      for (let c = 0; c < cols; c++) {
        if (maskData[r][c][0] !== 0) continue;
        if (r < minRow) minRow = r;
        if (r > maxRow) maxRow = r;
        if (c < minCol) minCol = c;
        if (c > maxCol) maxCol = c;
      }
    }

    // No zero cell was seen: the sentinels are untouched.
    if (minRow === rows || minCol === cols) {
      return { height: 0, width: 0 };
    }

    return { height: maxRow - minRow + 1, width: maxCol - minCol + 1 };
  }

  /**
   * Reports which borders of a [rows, cols, 1] mask contain a zero cell.
   *
   * A border counts as touched when any cell within `threshold` rows/columns
   * of it equals 0. The band is capped at the mask size, so masks smaller than
   * the threshold are handled instead of indexing out of range.
   *
   * @param mat Object exposing `shape` and `arraySync()`.
   * @param threshold Band width in cells; defaults to `this.maskThreshold`.
   * @returns 'on_edge_' followed by the touched sides (top, bottom, left,
   *   right, joined with '_'), or 'off_edge' when none is touched.
   */
  getTypeByMask(mat: any, threshold = this.maskThreshold) {
    const maskData = mat.arraySync();
    const [rows, cols] = mat.shape;
    const rowBand = Math.min(threshold, rows);
    const colBand = Math.min(threshold, cols);

    let onTopEdge = false;
    let onBottomEdge = false;
    let onLeftEdge = false;
    let onRightEdge = false;

    // Walk inwards from the top and bottom borders simultaneously.
    for (let r = 0; r < rowBand && !(onTopEdge && onBottomEdge); r++) {
      for (let c = 0; c < cols; c++) {
        if (maskData[r][c][0] === 0) onTopEdge = true;
        if (maskData[rows - 1 - r][c][0] === 0) onBottomEdge = true;
      }
    }

    // Same for the left and right borders.
    for (let c = 0; c < colBand && !(onLeftEdge && onRightEdge); c++) {
      for (let r = 0; r < rows; r++) {
        if (maskData[r][c][0] === 0) onLeftEdge = true;
        if (maskData[r][cols - 1 - c][0] === 0) onRightEdge = true;
      }
    }

    const edgeResults: string[] = [];
    if (onTopEdge) edgeResults.push('top');
    if (onBottomEdge) edgeResults.push('bottom');
    if (onLeftEdge) edgeResults.push('left');
    if (onRightEdge) edgeResults.push('right');

    return edgeResults.length > 0
      ? 'on_edge_' + edgeResults.join('_')
      : 'off_edge';
  }

  /**
   * Reports which frame borders an [x1, y1, x2, y2] box comes close to.
   *
   * The far borders are taken from `this.modelSize` (previously a literal 640)
   * so the check stays correct if the model input size changes.
   *
   * @param box [x1, y1, x2, y2] in model-input pixels.
   * @param threshold Distance from a border that still counts as touching.
   * @returns 'on_edge_' followed by the touched sides (left, top, right,
   *   bottom, joined with '_'), or 'off_edge'.
   */
  getTypeByBox(box: any, threshold = 4) {
    const [x1, y1, x2, y2] = box;
    const farLimit = this.modelSize - threshold;
    const edgeResults: string[] = [];

    if (x1 < threshold) edgeResults.push('left');
    if (y1 < threshold) edgeResults.push('top');
    if (x2 > farLimit) edgeResults.push('right');
    if (y2 > farLimit) edgeResults.push('bottom');

    return edgeResults.length > 0
      ? 'on_edge_' + edgeResults.join('_')
      : 'off_edge';
  }

  /**
   * Counts the cells of a boolean mask that equal `valueToCount`.
   *
   * @param mask Boolean tensor.
   * @param valueToCount 1/true counts set cells, 0/false counts unset cells.
   * @throws Error when `valueToCount` is neither 0/false nor 1/true.
   */
  maskArea(mask: any, valueToCount: any = 0) {
    return tf.tidy(() => {
      if (valueToCount === 1 || valueToCount === true) {
        return mask.sum().dataSync()[0];
      }
      if (valueToCount === 0 || valueToCount === false) {
        return mask.logicalNot().sum().dataSync()[0];
      }
      throw new Error('Value to count must be 0 or 1');
    });
  }

  /**
   * Picks the most relevant wall and appends at most one prompt about it.
   *
   * Sorts `layoutData` in place by descending area. The largest off-edge wall
   * is preferred; if it covers less than 5% of the frame and a larger on-edge
   * wall exists, that one is used instead.
   *
   * @param layoutData Entries with `label`, `area` and `type`.
   * @returns `this.promptHistory` (left unchanged when the photo looks fine).
   */
  getImageMaskQA(layoutData: any) {
    const entries: LayoutEntry[] = layoutData;
    entries.sort((a, b) => b.area - a.area);

    const walls = entries.filter((entry) => entry.label === 0);
    const offEdgeWalls = walls.filter((wall) => wall.type === 'off_edge');
    const onEdgeWalls = walls.filter((wall) => wall.type.includes('on_edge'));

    const frameArea = this.modelSize * this.modelSize;
    const tooSmall = frameArea * 0.05;
    const tooLarge = frameArea * 0.9;

    this.prompt_text = 'This photo looks good! ';
    let targetWall: LayoutEntry | undefined;

    if (walls.length < 1) {
      this.addPrompt(
        'Uh-oh, this does not look like a wall! If it is a wall, please try to zoom-out, step back as far as possible, and try angling your camera differently to include all of the wall. '
      );
    } else {
      const bestOffEdge = offEdgeWalls[0];
      const bestOnEdge = onEdgeWalls[0];

      if (bestOffEdge === undefined) {
        targetWall = bestOnEdge;
      } else if (
        bestOffEdge.area < tooSmall &&
        bestOnEdge !== undefined &&
        bestOnEdge.area > bestOffEdge.area
      ) {
        targetWall = bestOnEdge;
      } else {
        targetWall = bestOffEdge;
      }
    }

    if (!targetWall) {
      return this.promptHistory;
    }

    if (targetWall.area < tooSmall) {
      this.addPrompt('Please try to get closer to the wall for better quality. ');
    } else if (targetWall.area > tooLarge) {
      this.addPrompt(
        'Please try to zoom-out, step back as far as possible, and try angling your camera differently to include all of the wall. '
      );
    } else if (targetWall.type.includes('on_edge')) {
      const missing_sides = targetWall.type.split('on_edge_')[1].split('_');

      if (missing_sides.length === 1) {
        // Camera direction to suggest for each single missing side.
        const tiltFor: Record<string, string> = {
          top: 'up',
          bottom: 'down',
          left: 'left',
          right: 'right',
        };
        const side = missing_sides[0];
        const tilt = tiltFor[side];
        if (tilt !== undefined) {
          this.addPrompt(
            'Please try to zoom-out, step back as far as possible, and try angling your camera ' +
              tilt +
              ' to include the ' +
              side +
              ' edge of the wall. '
          );
        }
      } else {
        this.addPrompt(
          'Please try to zoom-out, step back as far as possible, and try angling your camera differently to include all of the wall. Missing edges: ' +
            missing_sides.join(', ') +
            '. '
        );
      }
    }

    return this.promptHistory;
  }

  /** Appends one prompt to the history for the current image. */
  addPrompt(text: string) {
    this.promptHistory.push(text);
  }
}
