# Vehicle line-crossing counter: YOLO detection + DeepSORT tracking, with crossings logged to MySQL.
from datetime import datetime

import cv2
import pymysql
from deep_sort_realtime.deepsort_tracker import DeepSort
from ultralytics import YOLO

# -------------------------------
# CONFIG
# -------------------------------
# Source handed to cv2.VideoCapture.
# NOTE(review): despite the name, the current value looks like a local video
# file rather than an RTSP URL -- confirm which one is intended.
RTSP_URL = "pagi.mp4"

# Y coordinates (in frame pixels) of the two horizontal counting lines.
# A track moving downward across LINE_A is logged as 'down'; a track moving
# upward across LINE_B is logged as 'up'.
LINE_A = 200
LINE_B = 500

# Keyword arguments passed verbatim to pymysql.connect().
# NOTE(review): root with an empty password is hard-coded here -- acceptable
# for a local demo only; move credentials out of source before deploying.
DB_CONFIG = {
    'host': 'localhost',
    'user': 'root',
    'password': '',
    'database': 'db_traffic2',
}

# -------------------------------
# Setup
# -------------------------------
# Frame source, detector and multi-object tracker used by the main loop.
cap = cv2.VideoCapture(RTSP_URL)
model = YOLO("yolov8n.pt")  # Downloads automatically
tracker = DeepSort(max_age=30)

# MySQL Connection
db = pymysql.connect(**DB_CONFIG)
cursor = db.cursor()

# Last observed vertical centre of each confirmed track:
# {track_id: {'y': prev_y}}
tracked_ids = {}

# Detector class labels that count as vehicles. The COCO label set shipped
# with YOLOv8 names two-wheelers "motorcycle"; "motorbike" is kept as well so
# custom-trained models using the older label keep working.
VEHICLE_CLASSES = frozenset({"car", "bus", "truck", "motorcycle", "motorbike"})


def crossing_direction(prev_y, center_y, line_a, line_b):
    """Return the direction in which a track crossed a counting line.

    Args:
        prev_y: Vertical centre of the track on its previous observation.
        center_y: Vertical centre of the track on the current frame.
        line_a: Y coordinate of the line that counts downward movement.
        line_b: Y coordinate of the line that counts upward movement.

    Returns:
        'down' if the track moved from above ``line_a`` to on/below it,
        'up' if it moved from below ``line_b`` to on/above it,
        otherwise ``None``.
    """
    if prev_y < line_a <= center_y:
        return 'down'
    if prev_y > line_b >= center_y:
        return 'up'
    return None


def main():
    """Run the detect -> track -> count loop until the stream ends or ESC.

    Uses the module-level ``cap``, ``model``, ``tracker``, ``db``, ``cursor``
    and ``tracked_ids`` objects plus the ``LINE_A`` / ``LINE_B`` settings.
    Each detected crossing is printed and inserted into the
    ``vehicle_crossings`` table. The capture, the display window, the cursor
    and the connection are released even if an error interrupts the loop.
    """
    sql = "INSERT INTO vehicle_crossings (vehicle_id, direction, crossing_time) VALUES (%s, %s, %s)"

    try:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                print("RTSP stream ended or cannot be read.")
                break

            # -------------------------------
            # YOLO Detection
            # -------------------------------
            results = model(frame, verbose=False)[0]
            detections = []
            for box in results.boxes:
                class_name = model.names[int(box.cls[0])]
                if class_name not in VEHICLE_CLASSES:
                    continue
                x1, y1, x2, y2 = box.xyxy[0].tolist()
                # The tracker takes ([left, top, width, height], conf, class).
                detections.append(
                    ([x1, y1, x2 - x1, y2 - y1], float(box.conf[0]), class_name)
                )

            # -------------------------------
            # DeepSORT Tracking
            # -------------------------------
            tracks = tracker.update_tracks(detections, frame=frame)

            # Forget ids the tracker no longer reports so the history dict
            # does not grow without bound on a long-running stream.
            live_ids = {track.track_id for track in tracks}
            for stale_id in tracked_ids.keys() - live_ids:
                del tracked_ids[stale_id]

            for track in tracks:
                if not track.is_confirmed():
                    continue

                track_id = track.track_id
                x1, y1, x2, y2 = map(int, track.to_ltrb())
                center_y = (y1 + y2) // 2

                prev = tracked_ids.get(track_id)
                tracked_ids[track_id] = {'y': center_y}

                # A crossing needs a previous position to compare against.
                if prev is not None:
                    direction = crossing_direction(
                        prev['y'], center_y, LINE_A, LINE_B
                    )
                    if direction is not None:
                        now = datetime.now()
                        print(f"[{now}] Vehicle {track_id} crossed {direction}")

                        # Save to DB
                        cursor.execute(
                            sql, (f"vehicle_{track_id}", direction, now)
                        )
                        db.commit()

                # Draw box and label for every confirmed track, not only for
                # the ones that crossed a line on this frame.
                cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
                cv2.putText(frame, f'ID {track_id}', (x1, y1 - 10),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)

            # Draw reference lines
            cv2.line(frame, (0, LINE_A), (frame.shape[1], LINE_A), (255, 0, 0), 2)
            cv2.line(frame, (0, LINE_B), (frame.shape[1], LINE_B), (0, 0, 255), 2)

            # Show frame
            cv2.imshow("YOLO + DeepSORT", frame)
            if cv2.waitKey(1) == 27:  # ESC to quit
                break
    finally:
        # Cleanup
        cap.release()
        cv2.destroyAllWindows()
        cursor.close()
        db.close()


if __name__ == "__main__":
    main()