import torch
from torch.serialization import add_safe_globals
from models.yolo import Model  # 這行要在 yolov5 專案目錄中執行
import cv2

# 從 PyTorch Hub 加載 YOLOv5 模型
model = torch.hub.load('ultralytics/yolov5', 'yolov5s')

# 初始化攝影機
cap = cv2.VideoCapture(0)

if not cap.isOpened():
    print("無法開啟攝影機")
    exit()

# 設定影像解析度
cap.set(cv2.CAP_PROP_FRAME_WIDTH, 640)
cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 480)

# 定義需要的類別
target_classes = ['person', 'cell phone']

while True:
    # 從攝影機讀取影像
    ret, frame = cap.read()
    if not ret:
        print("無法取得影像")
        break

    # OpenCV 的影像格式為 BGR，需要轉換為 RGB
    rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

    # 使用 YOLOv5 模型進行偵測
    results = model(rgb_frame)

    # 過濾目標類別
    detected_objects = results.pandas().xyxy[0]  # 取得偵測結果的 DataFrame
    filtered_objects = detected_objects[detected_objects['name'].isin(target_classes)]

    # 在影像上繪製過濾後的標註
    for _, row in filtered_objects.iterrows():
        x1, y1, x2, y2 = int(row['xmin']), int(row['ymin']), int(row['xmax']), int(row['ymax'])
        label = f"{row['name']} {row['confidence']:.2f}"
        # 繪製邊框與標籤
        cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
        cv2.putText(frame, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)

    # 顯示影像
    cv2.imshow('YOLOv5 Detection', frame)

    # 按下 'q' 鍵退出
    if cv2.waitKey(1) & 0xFF == ord('q'):
        break

# 釋放攝影機並關閉視窗
cap.release()
cv2.destroyAllWindows()