from PIL import Image
import torch
import torchvision.transforms as T

def create_pyramid(img, N, r, *, resample=Image.LANCZOS):
    """Build a multi-scale image pyramid from a PIL image.

    Level ``n`` (for ``n`` in ``0..N``) is ``img`` resized by a factor of
    ``r ** (N - n)``. The returned list therefore runs from the coarsest
    level (``x_0``, scale ``r ** N``) up to the original resolution
    (``x_N``, scale 1).

    Args:
        img: Input PIL image.
        N: Number of downsampling steps; the pyramid has ``N + 1`` levels.
            A negative ``N`` yields an empty list.
        r: Per-level scale factor (e.g. ``0.75`` to downsample); must be
            positive.
        resample: PIL resampling filter used for every level
            (default ``Image.LANCZOS``, as before).

    Returns:
        List ``[x_0, ..., x_N]`` of tensors, each of shape
        ``[1, C, H_n, W_n]``, produced by ``torchvision.transforms.ToTensor``
        plus a leading batch dimension.

    Raises:
        ValueError: If ``r`` is not positive.
    """
    if r <= 0:
        # A non-positive factor yields zero/negative sizes, which PIL rejects
        # anyway; fail early with a clearer message of the same type.
        raise ValueError(f"scale factor r must be positive, got {r!r}")

    w, h = img.size
    to_tensor = T.ToTensor()  # holds no state; build once rather than per level
    pyramid = []
    for n in range(N + 1):
        scale = r ** (N - n)
        # Truncate like int() did, but never below 1 px: PIL cannot resize to
        # a zero-sized image, which small r or deep pyramids would produce.
        size = (max(1, int(w * scale)), max(1, int(h * scale)))
        img_n = img.resize(size, resample)
        pyramid.append(to_tensor(img_n).unsqueeze(0))  # [1, C, H, W]
    return pyramid  # [x_0, ..., x_N], coarsest first