import cv2
import numpy as np
from model.base_model.ascnedcl.classes import CLASSES


def letterbox(img, new_shape=(640, 640), auto=False, scaleFill=False, scaleup=True, center=True, stride=32):
    """Resize ``img`` keeping its aspect ratio, then pad it with gray (114) borders.

    Args:
        img: image array; its first two dimensions are read as (height, width).
        new_shape: target (height, width), or a single int for a square target.
        auto: pad only up to the next multiple of ``stride`` instead of the full target.
        scaleFill: stretch the image to exactly ``new_shape`` with no padding.
        scaleup: when False, the image is only ever shrunk, never enlarged.
        center: split the padding between both sides; when False all padding goes
            to the bottom/right edges.
        stride: stride used by ``auto``.

    Returns:
        ``(padded_img, (width_ratio, height_ratio), dw, dh)``. ``dw``/``dh`` are the
        per-side horizontal/vertical padding when ``center`` is True, and the
        total padding otherwise.
    """
    src_h, src_w = img.shape[:2]
    if isinstance(new_shape, int):
        new_shape = (new_shape, new_shape)
    dst_h, dst_w = new_shape[0], new_shape[1]

    # A single scale factor that makes the image fit inside the target box.
    scale = min(dst_h / src_h, dst_w / src_w)
    if not scaleup:
        scale = min(scale, 1.0)

    resized_w = int(round(src_w * scale))
    resized_h = int(round(src_h * scale))
    pad_w = dst_w - resized_w
    pad_h = dst_h - resized_h
    ratio = (scale, scale)

    if auto:
        # Keep only the padding needed to reach a stride multiple.
        pad_w = np.mod(pad_w, stride)
        pad_h = np.mod(pad_h, stride)
    elif scaleFill:
        # Stretch to the target: no padding, independent x/y ratios.
        pad_w = pad_h = 0.0
        resized_w, resized_h = dst_w, dst_h
        ratio = (dst_w / src_w, dst_h / src_h)

    if center:
        pad_w /= 2
        pad_h /= 2

    target_size = (resized_w, resized_h)
    if (src_w, src_h) != target_size:
        img = cv2.resize(img, target_size, interpolation=cv2.INTER_LINEAR)

    # The +/-0.1 nudges make an odd total padding put the extra pixel on the
    # bottom/right side.
    if center:
        top = int(round(pad_h - 0.1))
        left = int(round(pad_w - 0.1))
    else:
        top = left = 0
    bottom = int(round(pad_h + 0.1))
    right = int(round(pad_w + 0.1))

    img = cv2.copyMakeBorder(
        img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(114, 114, 114)
    )
    return img, ratio, pad_w, pad_h

def non_max_suppression_v10(prediction, conf_thres, ratio, dw, dh):
    """Filter detections by confidence and map their boxes back to the original image.

    Despite the name, no non-maximum suppression happens here: each row is kept or
    dropped on its own confidence only.

    Args:
        prediction: 2-D array-like of detections. Each row is read as
            ``[x1, y1, x2, y2, confidence, class_id, ...]`` in letterboxed-image
            pixels (layout inferred from the indexing below; confirm against the
            model output).
        conf_thres: rows whose confidence is strictly greater than this are kept.
        ratio: ``(width_ratio, height_ratio)`` scale factors, as returned by
            ``letterbox``.
        dw, dh: per-side horizontal/vertical padding, as returned by ``letterbox``.
            This assumes ``letterbox`` was called with ``center=True``.

    Returns:
        list of ``[xmin, ymin, xmax, ymax, confidence, label]`` with integer
        coordinates in the original image.
    """
    # letterbox pads int(round(dw - 0.1)) pixels on the left and
    # int(round(dh - 0.1)) on the top. Every letterboxed coordinate (both
    # corners alike) is shifted by exactly these offsets. The right/bottom pad
    # must NOT be used for the far corner.
    pad_left = int(round(dw - 0.1))
    pad_top = int(round(dh - 0.1))

    result = []
    for data in prediction:
        # Class confidence of this detection.
        confidence = data[4]
        # Filter with the threshold.
        if confidence <= conf_thres:
            continue
        # Class index.
        label = int(data[5])
        # Undo padding, then scaling, to get original-image coordinates.
        xmin = int((data[0] - pad_left) / ratio[0])
        ymin = int((data[1] - pad_top) / ratio[1])
        xmax = int((data[2] - pad_left) / ratio[0])
        ymax = int((data[3] - pad_top) / ratio[1])
        result.append([xmin, ymin, xmax, ymax, confidence, label])
    return result


def draw_bbox_old(bbox, img0, color, wt):
    """Draw each detection's box, an "index class" caption and its score onto ``img0``.

    Args:
        bbox: 2-D array of rows ``[x1, y1, x2, y2, score, class_id, ...]`` (layout
            read from the indexing below).
        img0: image to draw on; it is passed through the ``cv2`` drawing calls.
        color: outline color for the boxes (captions are always red).
        wt: outline thickness for the boxes.

    Returns:
        The annotated image.
    """
    for idx, row in enumerate(bbox):
        # Skip very low-confidence detections. Compare the score itself, not a
        # float of a boolean.
        if float(row[4]) < 0.05:
            continue
        x1, y1 = int(row[0]), int(row[1])
        x2, y2 = int(row[2]), int(row[3])
        # Class ids may arrive as floats; list/tuple lookup needs an int.
        class_name = CLASSES[int(row[5])]
        img0 = cv2.rectangle(img0, (x1, y1), (x2, y2), color, wt)
        img0 = cv2.putText(img0, str(idx) + ' ' + class_name, (x1, int(row[1] + 16)), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1)
        img0 = cv2.putText(img0, '{:.4f}'.format(row[4]), (x1, int(row[1] + 32)), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1)
    return img0

_DRAW_BOX_PALETTE = None


def _class_color(class_id):
    '''Return a fixed BGR color tuple for ``class_id``; the palette is built once per process.'''
    global _DRAW_BOX_PALETTE
    if _DRAW_BOX_PALETTE is None:
        # A private, fixed-seed generator keeps each class's color stable across
        # calls (and runs) without consuming the global np.random state.
        rng = np.random.default_rng(0)
        _DRAW_BOX_PALETTE = rng.uniform(0, 255, size=(len(CLASSES), 3))
    return tuple(float(c) for c in _DRAW_BOX_PALETTE[class_id])


def draw_box(img,
             box,  # [xmin, ymin, xmax, ymax]
             score,
             class_id):
    '''Draw one bounding box with a "class: score" label onto ``img``.

    Args:
        img: image to draw on; it is modified in place by the ``cv2`` calls.
        box: ``[xmin, ymin, xmax, ymax]`` pixel coordinates (cast to int).
        score: detection score, shown with two decimals.
        class_id: index into ``CLASSES`` (cast to int).

    Returns:
        ``img`` with the box and label drawn.
    '''
    class_id = int(class_id)
    # The same class always gets the same color.
    color = _class_color(class_id)

    # cv2 drawing functions require integer points.
    xmin, ymin = int(box[0]), int(box[1])
    xmax, ymax = int(box[2]), int(box[3])

    # Draw the bounding box on the image
    cv2.rectangle(img, (xmin, ymin), (xmax, ymax), color, 2)

    # Create the label text with class name and score
    label = f'{CLASSES[class_id]}: {score:.2f}'

    # Calculate the dimensions of the label text
    (label_width, label_height), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)

    # Place the label just above the box, or just below its top edge when there
    # is not enough room above.
    label_x = xmin
    label_y = ymin - 10 if ymin - 10 > label_height else ymin + 10

    # Draw a filled rectangle as the background for the label text
    cv2.rectangle(
        img,
        (label_x, label_y - label_height),
        (label_x + label_width, label_y + label_height),
        color,
        cv2.FILLED,
    )

    # Draw the label text on the image
    cv2.putText(img, label, (label_x, label_y), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1, cv2.LINE_AA)

    return img