
[2D Vision] 2D Point Tracking: How to Use co-tracker

DrawingProcess 2025. 4. 24. 12:11
๋ฐ˜์‘ํ˜•
๐Ÿ’ก This post is a summary of '[2D Vision] 2D Point Tracking: How to Use co-tracker'.
It covers how to use co-tracker, a point-tracking module that is easy to get started with, so feel free to use it as a reference.

1. How to Use co-tracker: Quick Start

github ์—์„œ ๋‹จ์ˆœํ•˜๊ฒŒ ์–ธ๊ธ‰ํ•œ co-tracker ์‚ฌ์šฉ๋ฐฉ๋ฒ•์ž…๋‹ˆ๋‹ค. ์—ฌ๊ธฐ ๋ณด์ด๋“ฏ์ด torch.hub.load๋ฅผ ํ†ตํ•ด checkpoint๋ฅผ ๋ถˆ๋Ÿฌ์˜ฌ ๊ฒฝ์šฐ ๋”ฐ๋กœ co-tracker๋ฅผ git clone ํ•ด์„œ ์‚ฌ์šฉํ•  ํ•„์š”์—†์ด cotracker ํ•จ์ˆ˜๋ฅผ ํ†ตํ•ด ๋ถˆ๋Ÿฌ์™€์„œ ์‚ฌ์šฉํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. 

import torch
# Download the video
url = 'https://github.com/facebookresearch/co-tracker/raw/refs/heads/main/assets/apple.mp4'

import imageio.v3 as iio
frames = iio.imread(url, plugin="FFMPEG")  # plugin="pyav"

device = 'cuda'
grid_size = 10
video = torch.tensor(frames).permute(0, 3, 1, 2)[None].float().to(device)  # B T C H W

# Run Offline CoTracker:
cotracker = torch.hub.load("facebookresearch/co-tracker", "cotracker3_offline").to(device)
pred_tracks, pred_visibility = cotracker(video, grid_size=grid_size) # B T N 2,  B T N 1
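
For reference, pred_tracks holds the predicted (x, y) pixel coordinates of every tracked point in every frame, and pred_visibility marks whether each point is visible in each frame. A minimal sketch, reusing the variables above, of how you might inspect a single track:

# Inspect the outputs (shapes follow the comment above)
print(pred_tracks.shape, pred_visibility.shape)

# Trajectory of the first grid point across all frames, as a (T, 2) numpy array of (x, y) pixels
first_track = pred_tracks[0, :, 0].detach().cpu().numpy()
print(first_track[:5])  # positions of that point in the first 5 frames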

To visualize the results, copy co-tracker/cotracker/utils/visualizer.py into your working directory as visualizer.py; you can then render and check the tracking results as follows.

from visualizer import Visualizer

visualizer = Visualizer(
    save_dir="./saved_videos",  # directory where the rendered video is written
    fps=10,
    mode="rainbow",         # "rainbow" or "optical_flow"
    tracks_leave_trace=15,  # how many past frames each track leaves as a trace
)
res_video = visualizer.visualize(
    video=video,  # Visualizer expects pixel values in [0, 255]; multiply by 255 first if your tensor is normalized to [0, 1]
    tracks=pred_tracks,
    visibility=pred_visibility,
    filename="pred_tracks_visualization",
    save_video=True,
)

์œ„์˜ ๋ชจ๋“  ๊ณผ์ •์„ ์‹œ๊ฐํ™”ํ•ด๋†“์€ ํŒŒ์ผ์€ co-tracker/demo.py ์œ„์น˜์— ์žˆ์œผ๋‹ˆ ์ฐธ๊ณ ํ•˜์‹œ๊ธฐ ๋ฐ”๋ž๋‹ˆ๋‹ค.

+ Utils

์ถ”๊ฐ€์ ์œผ๋กœ ํ˜„์žฌ์˜ ์ฝ”๋“œ ๋‚ด์—์„œ๋Š” mp4 ํŒŒ์ผ์„ ์ฝ์–ด๋‹ค๊ฐ€ ์‚ฌ์šฉํ•˜๋Š” ์ฝ”๋“œ๋งŒ ์กด์žฌํ•ฉ๋‹ˆ๋‹ค. ๋”ฐ๋ผ์„œ visualizeํ•˜๋Š” ์œ„์น˜์ธ co-tracker/cotracker/utils/visualizer.py์— ์•„๋ž˜์˜ ํ•จ์ˆ˜๋ฅผ ์ถ”๊ฐ€ํ•˜๋ฉด sequence image๋ฅผ ์ฝ์–ด ์ด๋ฅผ ๋ฐ”ํƒ•์œผ๋กœ trackerํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

 

# Imports needed by this helper (some may already be present in visualizer.py)
import glob
import os

import numpy as np
from PIL import Image


def read_image_sequence_from_path(images_path):
    # Supported image file extensions
    img_extensions = ["*.jpg", "*.jpeg", "*.png", "*.bmp"]
    img_files = []
    
    # Collect image files with supported extensions
    for ext in img_extensions:
        img_files.extend(glob.glob(os.path.join(images_path, ext)))

    # Return None if no images are found
    if not img_files:
        print(f"No images found in path: {images_path}")
        return None

    # Sort files alphabetically to ensure proper sequence order
    img_files = sorted(img_files)

    frames = []
    for img_path in img_files:
        # Open each image and convert to RGB format
        img = Image.open(img_path).convert("RGB")
        frames.append(np.array(img))

    # Stack all images into a single numpy array (T, H, W, C)
    return np.stack(frames)
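
The returned (T, H, W, C) array can then be fed straight into the torch.hub quick-start pipeline from section 1. A minimal sketch, where the image-sequence path is just a placeholder:

# Read an image sequence and track it with the model loaded in the quick start above
frames = read_image_sequence_from_path("./my_image_sequence")  # placeholder path; returns a (T, H, W, C) uint8 array
video = torch.tensor(frames).permute(0, 3, 1, 2)[None].float().to(device)  # B T C H W
pred_tracks, pred_visibility = cotracker(video, grid_size=grid_size)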

To use this function, modify demo.py as follows.

    args = parser.parse_args()
    # load the input video frame by frame
    if args.video_path.endswith(".mp4"):
        video = read_video_from_path(args.video_path)
    else:
        video = read_image_sequence_from_path(args.video_path)

2. How to Use co-tracker: Using It After Modifying the Module

TODO...
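
(For reference, the official README also shows a module-style usage roughly like the sketch below, assuming the repository has been cloned and installed and a checkpoint such as scaled_offline.pth has been downloaded into ./checkpoints; exact file and argument names may differ by version.)

import torch
from cotracker.predictor import CoTrackerPredictor

device = "cuda"
# Load a locally downloaded checkpoint instead of fetching the model through torch.hub
model = CoTrackerPredictor(checkpoint="./checkpoints/scaled_offline.pth").to(device)

# video: a B T C H W float tensor, prepared exactly as in the quick start above
pred_tracks, pred_visibility = model(video, grid_size=10)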

References

[Github] co-tracker: https://github.com/facebookresearch/co-tracker

 

๋ฐ˜์‘ํ˜•