Loading Custom ReID models and Poor Performance #1618

Open
1 task done
MahejabeenNidhi opened this issue Sep 9, 2024 · 2 comments
Labels
question Further information is requested Stale

Comments

@MahejabeenNidhi

Search before asking

  • I have searched the Yolo Tracking issues and found no similar bug report.

Question

Thank you so much for your work. I have two questions.

Question 1

I trained a ReID model with torchreid, and it gives me a .pth file. I noticed that in the StrongSORT repo the ReID weights are always .pt files. Would I be able to just use the .pth file that my code below generates? This is especially important in my case, as I am tracking non-human/non-vehicle objects.

import os
import torch
import string
import random
import argparse
import torchreid
from glob import glob
import os.path as osp


class NewDataset(torchreid.data.datasets.ImageDataset):
    dataset_dir = ''

    def __init__(self, root='', **kwargs):
        self.train_dir = self.dataset_dir
        self.query_dir = self.dataset_dir
        self.gallery_dir = self.dataset_dir

        train = self.process_dir(self.train_dir, isQuery=False)
        query = self.process_dir(self.query_dir, isQuery=True)
        gallery = self.process_dir(self.gallery_dir, isQuery=False)

        super(NewDataset, self).__init__(train, query, gallery, **kwargs)

    def process_dir(self, dir_path, isQuery, relabel=True):
        # Filenames are expected to look like <x><camid>_<y><pid>_*.jpg,
        # where <x> and <y> are one-character prefixes.
        img_paths = glob(osp.join(dir_path, '*.jpg'))

        pid_container = set()
        for img_path in img_paths:
            img_name = osp.basename(img_path)
            name_splitted = img_name.split('_')
            pid = int(name_splitted[1][1:])
            pid_container.add(pid)

        # Sort so the pid -> label mapping is the same on every run
        pid2label = {pid: label for label, pid in enumerate(sorted(pid_container))}

        data = []
        for img_path in img_paths:
            img_name = osp.basename(img_path)
            name_splitted = img_name.split('_')
            pid = int(name_splitted[1][1:])
            camid = int(name_splitted[0][1:])

            if isQuery:
                camid += 10  # offset query camera IDs so they differ from the gallery's

            if relabel:
                pid = pid2label[pid]

            data.append((img_path, pid, camid))

        return data


def get_parser():
    parser = argparse.ArgumentParser()

    parser.add_argument('--name', type=str, default='osnet_x1_0', help="ReID model name")
    parser.add_argument('--img_h', type=int, default=256, help="image height")
    parser.add_argument('--img_w', type=int, default=128, help="image width")
    parser.add_argument('--bs', type=int, default=32, help="batch size")
    parser.add_argument('--optim', type=str, default='adam', help="optimizer")
    parser.add_argument('--lr', type=float, default=0.003, help="learning rate")
    parser.add_argument('--lr_sch', type=str, default="single_step", help="learning rate scheduler")
    parser.add_argument('--step', type=int, default=5, help="learning rate scheduler's step size")
    parser.add_argument('--epochs', type=int, default=20, help="epoch count for the training loop")
    parser.add_argument('--eval_freq', type=int, default=5, help="evaluation frequency")
    parser.add_argument('--data_path', type=str, default='path/to/data', help="path to the custom dataset")
    parser.add_argument('--save_path', type=str, default='path/to/save', help="path to save the model")

    args = parser.parse_args()

    return args


def main(args):
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    NewDataset.dataset_dir = args.data_path
    dataset_name = ''.join(random.choices(string.ascii_uppercase + string.digits, k=random.randint(1, 25)))
    torchreid.data.register_image_dataset(dataset_name, NewDataset)

    datamanager = torchreid.data.ImageDataManager(
        sources=dataset_name,
        height=args.img_h,
        width=args.img_w,
        batch_size_train=args.bs,
        batch_size_test=100,
        transforms=["random_flip", "random_crop"]
    )

    model = torchreid.models.build_model(
        name=args.name,
        num_classes=datamanager.num_train_pids,
        loss="triplet",
        pretrained=True
    ).to(device).train()

    optimizer = torchreid.optim.build_optimizer(
        model,
        optim=args.optim,
        lr=args.lr,
    )

    scheduler = torchreid.optim.build_lr_scheduler(
        optimizer,
        lr_scheduler=args.lr_sch,
        stepsize=args.step,
    )

    engine = torchreid.engine.ImageTripletEngine(
        datamanager,
        model,
        optimizer=optimizer,
        scheduler=scheduler,
        margin=0.3,  # by default 0.3
        weight_t=1,  # weight for triplet loss
        weight_x=50,  # weight for softmax loss
    )

    engine.run(
        save_dir=f"log/{args.name}",
        max_epoch=args.epochs,
        eval_freq=args.eval_freq,
        print_freq=50,
        test_only=False
    )

    # Save the trained model (create the target directory if it does not exist)
    os.makedirs(args.save_path, exist_ok=True)
    model_save_path = os.path.join(args.save_path, f"{args.name}_model.pth")
    torch.save(model.state_dict(), model_save_path)
    print(f"Trained model saved at: {model_save_path}")


if __name__ == '__main__':
    args = get_parser()
    main(args)
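
A minimal sketch of the filename convention that process_dir above relies on: the first underscore-separated field is a one-character prefix followed by the camera ID, and the second is a one-character prefix followed by the identity ID. The helper names parse_reid_filename and build_pid2label and the example filenames are hypothetical, chosen only for illustration. The sketch sorts the identity IDs before relabeling so that the mapping is reproducible.

```python
import os.path as osp


def parse_reid_filename(img_path):
    """Return (pid, camid) parsed the same way process_dir does.

    Hypothetical helper: assumes names like 'c3_p12_0001.jpg'.
    """
    img_name = osp.basename(img_path)
    fields = img_name.split('_')
    camid = int(fields[0][1:])  # drop the one-character prefix, e.g. 'c3' -> 3
    pid = int(fields[1][1:])    # drop the one-character prefix, e.g. 'p12' -> 12
    return pid, camid


def build_pid2label(pids):
    """Map raw identity IDs to contiguous labels 0..N-1 in sorted order."""
    return {pid: label for label, pid in enumerate(sorted(set(pids)))}


# Hypothetical example filenames
paths = ['data/c1_p12_0001.jpg', 'data/c2_p12_0002.jpg', 'data/c1_p40_0003.jpg']
parsed = [parse_reid_filename(p) for p in paths]
pid2label = build_pid2label(pid for pid, _ in parsed)
print(parsed)     # [(12, 1), (12, 2), (40, 1)]
print(pid2label)  # {12: 0, 40: 1}
```

Running a check like this on a few real filenames before training is a cheap way to confirm that identity and camera IDs are being read from the intended fields.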

Question 2

I used the built-in tracker in YOLOv8 and the performance was much better. I used custom weights for my unique class, and they worked a lot better there than when I used this repo with the following command:

python tracking/track.py --source ../LabelledTracking/D05-AA-01_LM --yolo-model tracking/weights/yolov8_best.pt --tracking-method botsort --imgsz 2160 --save --save-txt

I get better performance with the following code:

import cv2
import numpy as np
from ultralytics import YOLO
import os
import random

# Function to generate a random RGB color
def random_color():
    return (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))

# Function to draw bounding boxes and center points
def draw_boxes_and_centers(frame, boxes, clss, track_ids, confs, img_size, object_colors):
    img = frame.copy()
    centers = []

    for i, box in enumerate(boxes):
        cls_id = int(clss[i])
        track_id = int(track_ids[i])
        conf = confs[i]

        x1, y1, x2, y2 = map(int, box)
        center_x, center_y = (x1 + x2) // 2, (y1 + y2) // 2

        # Get color for the object (generate a new color if not assigned)
        if track_id not in object_colors:
            object_colors[track_id] = random_color()
        color = object_colors[track_id]

        # Draw bounding box
        cv2.rectangle(img, (x1, y1), (x2, y2), color, 2)

        # Draw track ID and center point
        label = f"{track_id}"
        cv2.putText(img, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
        cv2.circle(img, (center_x, center_y), 3, color, -1)

        centers.append((center_x / img_size[0], center_y / img_size[1], track_id))

    return img, centers, object_colors

# Function to create an image showing center point tracks
def create_tracks_image(centers, img_size, object_colors):
    track_img = np.ones((img_size[1], img_size[0], 3), dtype=np.uint8) * 255
    track_centers = {}

    for center_x, center_y, track_id in centers:
        if track_id not in track_centers:
            track_centers[track_id] = []

        track_centers[track_id].append((int(center_x * img_size[0]), int(center_y * img_size[1])))

    for track_id, points in track_centers.items():
        color = object_colors[track_id]
        for i in range(len(points) - 1):
            cv2.line(track_img, points[i], points[i + 1], color, 2)
        cv2.circle(track_img, points[-1], 3, color, -1)  # Draw the last point as a circle

    return track_img

output_directory = "../../path"
os.makedirs(output_directory, exist_ok=True)

model_path = '../../runs/detect/train17/weights/best.pt'
model = YOLO(model_path)  # Load a custom trained model
names = model.model.names

image_folder = "../../dataset"
assert os.path.exists(image_folder), "Image folder not found"

object_colors = {}
all_centers = []
frame_count = 0
for filename in sorted(os.listdir(image_folder)):
    if filename.lower().endswith(('.png', '.jpg', '.jpeg')):
        img_path = os.path.join(image_folder, filename)
        frame = cv2.imread(img_path)
        assert frame is not None, f"Failed to read image {filename}"

        h, w, _ = frame.shape
        img_size = (w, h)

        # Extract prediction results
        results = model.track(frame, persist=True, verbose=False)
        det = results[0].boxes
        if det.id is None:
            # No track IDs were assigned in this frame, so there is nothing to draw
            boxes, clss, track_ids, confs = [], [], [], []
        else:
            boxes = det.xyxy.cpu().numpy()
            clss = det.cls.cpu().tolist()
            track_ids = det.id.int().cpu().tolist()
            confs = det.conf.float().cpu().tolist()

        # Draw bounding boxes and center points
        annotated_frame, centers, object_colors = draw_boxes_and_centers(frame, boxes, clss, track_ids, confs, img_size, object_colors)
        all_centers.extend(centers)

        frame_filename = f"output_{frame_count}.jpg"
        frame_path = os.path.join(output_directory, frame_filename)
        cv2.imwrite(frame_path, annotated_frame)

        frame_count += 1

The detections I get with BoxMOT look much worse by comparison, which makes me wonder whether the weights loaded properly. What would you advise?

Thank you so much for your time!

@MahejabeenNidhi MahejabeenNidhi added the question Further information is requested label Sep 9, 2024
@mikel-brostrom
Owner

mikel-brostrom commented Sep 9, 2024

Would I be able to just use the .pth file that my code below generates?

Yup, no problem. Just change the suffix from .pth to .pt.
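
A minimal sketch of that suffix change, assuming the checkpoint was written by the training script above. The helper name copy_with_pt_suffix is hypothetical. It copies rather than renames, so the original .pth file stays in place. As far as I know, BoxMOT picks the ReID architecture from the weights filename, so keeping the architecture name (for example osnet_x1_0) in the filename is probably the safest choice; treat that as an assumption to check against your installed version.

```python
import shutil
from pathlib import Path


def copy_with_pt_suffix(pth_path):
    """Copy a .pth checkpoint next to itself with a .pt suffix; return the new path."""
    src = Path(pth_path)
    if src.suffix != '.pth':
        raise ValueError(f"expected a .pth file, got: {src.name}")
    dst = src.with_suffix('.pt')
    shutil.copy2(src, dst)  # copy, not rename, so the original file is kept
    return dst


# Example call (path taken from the training script's defaults):
# copy_with_pt_suffix('path/to/save/osnet_x1_0_model.pth')
```

Only the file extension changes; the contents of the checkpoint are not modified.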


👋 Hello, this issue has been automatically marked as stale because it has not had recent activity. Please note it will be closed if no further activity occurs.
Feel free to inform us of any other issues you discover or feature requests that come to mind in the future. Pull Requests (PRs) are also always welcomed!

@github-actions github-actions bot added the Stale label Sep 20, 2024