[2D Vision] 2D Point Tracking: co-tracker 사용법
💡 본 문서는 '[2D Vision] 2D Point Tracking: co-tracker 사용법'에 대해 정리해놓은 글입니다.
๊ฐ๋จํ๊ฒ ์ฌ์ฉํ๊ธฐ ์ข์ Point Tracking ๋ชจ๋์ธ co-tracker๋ฅผ ์ฌ์ฉํ๋ ๋ฐฉ๋ฒ์ ๋ํด ์ ๋ฆฌํ์์ผ๋ ์ฐธ๊ณ ํ์๊ธฐ ๋ฐ๋๋๋ค.
1. co-tracker 사용법: Quick Start
github ์์ ๋จ์ํ๊ฒ ์ธ๊ธํ co-tracker ์ฌ์ฉ๋ฐฉ๋ฒ์ ๋๋ค. ์ฌ๊ธฐ ๋ณด์ด๋ฏ์ด torch.hub.load๋ฅผ ํตํด checkpoint๋ฅผ ๋ถ๋ฌ์ฌ ๊ฒฝ์ฐ ๋ฐ๋ก co-tracker๋ฅผ git clone ํด์ ์ฌ์ฉํ ํ์์์ด cotracker ํจ์๋ฅผ ํตํด ๋ถ๋ฌ์์ ์ฌ์ฉํ ์ ์์ต๋๋ค.
import torch
import imageio.v3 as iio

# Download the demo video. The permute below expects frames laid out as (T, H, W, C).
url = 'https://github.com/facebookresearch/co-tracker/raw/refs/heads/main/assets/apple.mp4'
frames = iio.imread(url, plugin="FFMPEG")  # plugin="pyav" is an alternative reader

# Fall back to CPU so the example also runs on machines without a CUDA GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
grid_size = 10  # density of the regular grid of query points passed to CoTracker
# (T, H, W, C) -> (T, C, H, W), then add a batch axis. `.float()` does not
# rescale, so pixel values keep the range of `frames` (0-255 for 8-bit video).
video = torch.tensor(frames).permute(0, 3, 1, 2)[None].float().to(device)  # B T C H W

# Run Offline CoTracker. The weights are fetched through torch.hub, so no git clone is needed.
cotracker = torch.hub.load("facebookresearch/co-tracker", "cotracker3_offline").to(device)
pred_tracks, pred_visibility = cotracker(video, grid_size=grid_size)  # B T N 2, B T N 1
์ถ๊ฐ๋ก ๊ฒฐ๊ณผ๋ฅผ ์๊ฐํ ํ๊ธฐ ์ํด์๋ co-tracker/cotracker/utils/visualizer.py์ ํ์ผ์ ๋ณต์ฌํด์ ์ฌ์ฉํ๊ณ ์ถ์ ๊ณณ์ visualizer.py๋ก ๋ถ์ฌ๋ฃ๊ธฐ ํ ํ์ ๊ฒฐ๊ณผ๋ฅผ ์๊ฐํํด์ ํ์ธํ ์ ์์ต๋๋ค.
from visualizer import Visualizer

# Directory the rendered video is written to. The snippet used this name
# without ever defining it, which raised NameError; choose any output path.
output_cotracker = "./results"

visualizer = Visualizer(
    save_dir=output_cotracker,
    fps=10,
    mode="rainbow",  # rainbow or optical_flow
    tracks_leave_trace=15,  # length of the trail drawn behind each track (presumably in frames)
)
res_video = visualizer.visualize(
    # `video` was built with a plain `.float()` cast of the decoded frames, so it
    # is already on the 0-255 scale; multiplying by 255 again would push pixel
    # values far outside the displayable range.
    video=video,
    tracks=pred_tracks,
    # Pass the predicted visibility directly. The previous check looked for a
    # variable named 'visibility_tensor' that is never defined, so it always
    # sent None and the visibility prediction was silently discarded.
    visibility=pred_visibility,
    filename="pred_tracks_visualization",
    save_video=True,
)
์์ ๋ชจ๋ ๊ณผ์ ์ ์๊ฐํํด๋์ ํ์ผ์ co-tracker/demo.py ์์น์ ์์ผ๋ ์ฐธ๊ณ ํ์๊ธฐ ๋ฐ๋๋๋ค.
+ Utils
์ถ๊ฐ์ ์ผ๋ก ํ์ฌ์ ์ฝ๋ ๋ด์์๋ mp4 ํ์ผ์ ์ฝ์ด๋ค๊ฐ ์ฌ์ฉํ๋ ์ฝ๋๋ง ์กด์ฌํฉ๋๋ค. ๋ฐ๋ผ์ visualizeํ๋ ์์น์ธ co-tracker/cotracker/utils/visualizer.py์ ์๋์ ํจ์๋ฅผ ์ถ๊ฐํ๋ฉด sequence image๋ฅผ ์ฝ์ด ์ด๋ฅผ ๋ฐํ์ผ๋ก trackerํ ์ ์์ต๋๋ค.
def _natural_sort_key(path):
    """Return a sort key for a file path that compares digit runs numerically.

    Plain string sorting puts ``frame_10.png`` before ``frame_2.png``. This key
    orders unpadded frame numbers correctly. Names that differ only in an
    equal-width zero-padded number keep their plain-sort order.
    """
    import os
    import re

    name = os.path.basename(path)
    # re.split with a capturing group alternates text / digit runs. Even
    # indices are always text and odd indices always digit runs, so comparing
    # two keys never mixes int and str at the same position.
    parts = [
        int(tok) if i % 2 else tok
        for i, tok in enumerate(re.split(r"(\d+)", name))
    ]
    # Tie-break on the raw name so e.g. "f01.png" vs "f1.png" sort deterministically.
    return parts, name


def read_image_sequence_from_path(images_path):
    """Read all images in a directory as a single RGB frame sequence.

    Regular files directly inside ``images_path`` whose extension is .jpg,
    .jpeg, .png or .bmp (matched case-insensitively) are loaded in natural
    filename order, so ``frame_2`` comes before ``frame_10``. Hidden
    dot-files are skipped, matching the earlier glob-based behaviour.

    Relies on module-level ``np`` (numpy) and ``Image`` (PIL). The host
    module, e.g. co-tracker's visualizer.py, is expected to import both.
    NOTE(review): confirm this if the function is pasted elsewhere.

    Args:
        images_path: Directory containing the frame images.

    Returns:
        A numpy array stacked along a new first axis, shaped (T, H, W, C)
        with C == 3 after RGB conversion. Returns None when the directory
        does not exist or holds no supported image; a message is printed in
        that case.

    Raises:
        ValueError: Raised by ``np.stack`` when the frames differ in size.
    """
    import os

    img_extensions = {".jpg", ".jpeg", ".png", ".bmp"}

    # os.listdir avoids glob, which is not imported by the host module and
    # which misreads directory names containing glob metacharacters like "[".
    img_files = []
    if os.path.isdir(images_path):
        for name in os.listdir(images_path):
            if name.startswith("."):  # glob's "*" never matched dot-files either
                continue
            full_path = os.path.join(images_path, name)
            if (
                os.path.splitext(name)[1].lower() in img_extensions
                and os.path.isfile(full_path)  # skip directories named like "x.png"
            ):
                img_files.append(full_path)

    # Return None if no images are found
    if not img_files:
        print(f"No images found in path: {images_path}")
        return None

    img_files.sort(key=_natural_sort_key)

    frames = []
    for img_path in img_files:
        # The context manager closes each file handle instead of leaking it until GC.
        with Image.open(img_path) as img:
            frames.append(np.array(img.convert("RGB")))
    # Stack all images into a single numpy array (T, H, W, C)
    return np.stack(frames)
이 함수를 사용하기 위해서는 아래와 같이 demo.py 파일을 수정하시면 됩니다.
import os

args = parser.parse_args()
# Load the input either as a directory of frame images or as a video file.
# NOTE(review): demo.py must also import read_image_sequence_from_path, e.g.
# next to its existing read_video_from_path import. Confirm against demo.py.
if os.path.isdir(args.video_path):
    # A directory is treated as an image sequence (jpg/jpeg/png/bmp frames).
    video = read_image_sequence_from_path(args.video_path)
else:
    # Any other path goes to the video reader. The old case-sensitive
    # `.endswith(".mp4")` test sent ".MP4", ".mov", ".avi", ... to the
    # image-sequence reader, which then found no images.
    video = read_video_from_path(args.video_path)

# read_image_sequence_from_path returns None when it finds no images (and
# read_video_from_path presumably does so on failure). Stop here with a clear
# message instead of crashing later on a None tensor.
if video is None:
    parser.error(f"could not read any frames from {args.video_path!r}")
2. co-tracker 사용법: Module 변경 후 활용
TODO...
참고
[Github] co-tracker: https://github.com/facebookresearch/co-tracker