Pytorch中的仿射變換(affine_grid)

在看 pytorch 的 Spatial Transformer Network 教程 時,在 stn 層中的 affine_gridgrid_sample 函數(shù)上卡住了,不知道這兩個函數(shù)該如何使用,經過一些實驗終于搞清楚了其作用。

參考:詳細解讀Spatial Transformer Networks (STN),該文章與李宏毅的課程一樣,推薦聽李老師的 STN 這一課,講的比較清楚;

假設我們有這么一張圖片:


魁拔中的卡拉肖克·玲

下面我們將通過分別通過手動編碼和pytorch方式對該圖片進行平移、旋轉、轉置、縮放等操作,這些操作的數(shù)學原理在本文中不會詳細講解。

實現(xiàn)載入圖片(注意,下面的代碼都是在 jupyter 中進行):

from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt

%matplotlib inline

img_path = "圖片文件路徑"
img_torch = transforms.ToTensor()(Image.open(img_path))

plt.imshow(img_torch.numpy().transpose(1,2,0))
plt.show()
圖片載入

平移操作

普通方式

例如我們需要向右平移50px,向下平移100px。

import numpy as np
import torch

theta = np.array([
    [1,0,50],
    [0,1,100]
])
# 變換1:可以實現(xiàn)縮放/旋轉,這里為 [[1,0],[0,1]] 保存圖片不變
t1 = theta[:,[0,1]]
# 變換2:可以實現(xiàn)平移
t2 = theta[:,[2]]

_, h, w = img_torch.size()
new_img_torch = torch.zeros_like(img_torch, dtype=torch.float)
for x in range(w):
    for y in range(h):
        pos = np.array([[x], [y]])
        npos = t1@pos+t2
        nx, ny = npos[0][0], npos[1][0]
        if 0<=nx<w and 0<=ny<h:
            new_img_torch[:,ny,nx] = img_torch[:,y,x]
plt.imshow(new_img_torch.numpy().transpose(1,2,0))
plt.show()

圖片變?yōu)椋?/p>

圖片平移-1

pytorch 方式

向右移動0.2,向下移動0.4:

from torch.nn import functional as F

theta = torch.tensor([
    [1,0,-0.2],
    [0,1,-0.4]
], dtype=torch.float)
grid = F.affine_grid(theta.unsqueeze(0), img_torch.unsqueeze(0).size())
output = F.grid_sample(img_torch.unsqueeze(0), grid)
new_img_torch = output[0]
plt.imshow(new_img_torch.numpy().transpose(1,2,0))
plt.show()

得到的圖片為:


圖片平移-2

總結:

  • 要使用 pytorch 的平移操作,只需要兩步:
    • 創(chuàng)建 grid:grid = torch.nn.functional.affine_grid(theta, size),其實我們可以通過調節(jié) size 設置所得到的圖像的大小(相當于resize);
    • grid_sample 進行重采樣:outputs = torch.nn.functional.grid_sample(inputs, grid, mode='bilinear')
  • theta 的第三列為平移比例,向右為負,向下為負;

我們通過設置 size 可以將圖像resize:

from torch.nn import functional as F

theta = torch.tensor([
    [1,0,-0.2],
    [0,1,-0.4]
], dtype=torch.float)
# 修改size
N, C, W, H = img_torch.unsqueeze(0).size()
size = torch.Size((N, C, W//2, H//3))
grid = F.affine_grid(theta.unsqueeze(0), size)
output = F.grid_sample(img_torch.unsqueeze(0), grid)
new_img_torch = output[0]
plt.imshow(new_img_torch.numpy().transpose(1,2,0))
plt.show()
修改size的效果

縮放操作

普通方式

放大1倍:

import numpy as np
import torch

theta = np.array([
    [2,0,0],
    [0,2,0]
])
t1 = theta[:,[0,1]]
t2 = theta[:,[2]]

_, h, w = img_torch.size()
new_img_torch = torch.zeros_like(img_torch, dtype=torch.float)
for x in range(w):
    for y in range(h):
        pos = np.array([[x], [y]])
        npos = t1@pos+t2
        nx, ny = npos[0][0], npos[1][0]
        if 0<=nx<w and 0<=ny<h:
            new_img_torch[:,ny,nx] = img_torch[:,y,x]
plt.imshow(new_img_torch.numpy().transpose(1,2,0))
plt.show()

結果為:


放大操作-1

由于沒有使用插值算法,所以中間有很多部分是黑色的。

pytorch 方式

from torch.nn import functional as F

theta = torch.tensor([
    [0.5, 0  , 0],
    [0  , 0.5, 0]
], dtype=torch.float)
grid = F.affine_grid(theta.unsqueeze(0), img_torch.unsqueeze(0).size())
output = F.grid_sample(img_torch.unsqueeze(0), grid)
new_img_torch = output[0]
plt.imshow(new_img_torch.numpy().transpose(1,2,0))
plt.show()

結果為:

放大操作-2

結論:可以看到,affine_grid 的放大操作是以圖片中心為原點的。

旋轉操作

普通操作

將圖片旋轉30度:

import numpy as np
import torch
import math

angle = 30*math.pi/180
theta = np.array([
    [math.cos(angle),math.sin(-angle),0],
    [math.sin(angle),math.cos(angle) ,0]
])
t1 = theta[:,[0,1]]
t2 = theta[:,[2]]

_, h, w = img_torch.size()
new_img_torch = torch.zeros_like(img_torch, dtype=torch.float)
for x in range(w):
    for y in range(h):
        pos = np.array([[x], [y]])
        npos = t1@pos+t2
        nx, ny = int(npos[0][0]), int(npos[1][0])
        if 0<=nx<w and 0<=ny<h:
            new_img_torch[:,ny,nx] = img_torch[:,y,x]
plt.imshow(new_img_torch.numpy().transpose(1,2,0))
plt.show()

結果為:
旋轉操作-1

pytorch 操作

from torch.nn import functional as F
import math

angle = -30*math.pi/180
theta = torch.tensor([
    [math.cos(angle),math.sin(-angle),0],
    [math.sin(angle),math.cos(angle) ,0]
], dtype=torch.float)
grid = F.affine_grid(theta.unsqueeze(0), img_torch.unsqueeze(0).size())
output = F.grid_sample(img_torch.unsqueeze(0), grid)
new_img_torch = output[0]
plt.imshow(new_img_torch.numpy().transpose(1,2,0))
plt.show()

結果為:


旋轉操作-2

pytorch 以圖片中心為原點進行旋轉,并且在旋轉過程中會發(fā)生圖片縮放,如果選擇角度變?yōu)?90°,圖片為:


旋轉 90° 結果

轉置操作

普通操作

import numpy as np
import torch

theta = np.array([
    [0,1,0],
    [1,0,0]
])
t1 = theta[:,[0,1]]
t2 = theta[:,[2]]

_, h, w = img_torch.size()
new_img_torch = torch.zeros_like(img_torch, dtype=torch.float)
for x in range(w):
    for y in range(h):
        pos = np.array([[x], [y]])
        npos = t1@pos+t2
        nx, ny = npos[0][0], npos[1][0]
        if 0<=nx<w and 0<=ny<h:
            new_img_torch[:,ny,nx] = img_torch[:,y,x]
plt.imshow(new_img_torch.numpy().transpose(1,2,0))
plt.show()

結果為:


圖片轉置-1

pytorch 操作

我們可以通過size大小,保存圖片不被壓縮:

from torch.nn import functional as F

theta = torch.tensor([
    [0, 1, 0],
    [1, 0, 0]
], dtype=torch.float)
N, C, H, W = img_torch.unsqueeze(0).size()
grid = F.affine_grid(theta.unsqueeze(0), torch.Size((N, C, W, H)))
output = F.grid_sample(img_torch.unsqueeze(0), grid)
new_img_torch = output[0]
plt.imshow(new_img_torch.numpy().transpose(1,2,0))
plt.show()

結果為:


圖片轉置-2

上面就是 affine_grid + grid_sample 的大致用法,如果你在看 STN 時有相同的用法,希望可以幫助到你。

?著作權歸作者所有,轉載或內容合作請聯(lián)系作者
【社區(qū)內容提示】社區(qū)部分內容疑似由AI輔助生成,瀏覽時請結合常識與多方信息審慎甄別。
平臺聲明:文章內容(如有圖片或視頻亦包括在內)由作者上傳并發(fā)布,文章內容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務。

相關閱讀更多精彩內容

友情鏈接更多精彩內容