CNN中的卷積操作

目錄:

  • 1.CNN中的卷積操作
    • 直接卷積法
    • 通用矩陣乘法GEMM
  • 2.手動(dòng)實(shí)現(xiàn)Conv2d

一、卷積神經(jīng)網(wǎng)絡(luò)中的卷積操作

直接卷積法

代碼實(shí)現(xiàn):

# 根據(jù)公式計(jì)算卷積的尺寸
def cal_convoluation_size(input, kernel, padding=0, stride=1, dilation=1):
    new_kernel = dilation * (kernel - 1) + 1  # 空洞卷積,空洞數(shù)為0時(shí)dilation=1
    # 根據(jù)公式計(jì)算輸出,并返回
    return math.floor((input + 2 * padding - new_kernel) / stride + 1)

# 簡(jiǎn)單版本的直接卷積法:不考慮padding,dilation=1,padding=0
def convoluation(image, kernel):
    image_height, image_width, channels = image.shape
    kernel_height, kernel_width = kernel.shape
    # 計(jì)算輸出的形狀大小
    out_height = cal_convoluation_size(image_height, kernel_height)
    out_width = cal_convoluation_size(image_width, kernel_width)
    output = np.zeros((out_height, out_width, channels))

    # 計(jì)算output的每個(gè)像素值
    # 先找到目標(biāo)圖(dx, dy)對(duì)應(yīng)原圖中的中心點(diǎn)位置(cx, cy),然后計(jì)算
    for dy in range(out_height):
        for dx in range(out_width):
            # 遍歷kernel計(jì)算輸出(output[dy, dx])的像素值
            for ky in range(kernel_height):
                for kx in range(kernel_width):
                    kernel_value = kernel[ky, kx]
                    pixel_value = image[dy + ky, dx + kx]
                    output[dy, dx] += kernel_value * pixel_value   
      
    return output
通用矩陣乘法GEMM

針對(duì)卷積速度慢的問題,使用GEMM進(jìn)行優(yōu)化。
(還可以對(duì)GEMM進(jìn)一步優(yōu)化,感興趣的同學(xué)可以自行去了解下Winograd算法。)

GEMM的核心思想是img2col。img2col的流程如下:

代碼實(shí)現(xiàn):

# 根據(jù)公式計(jì)算卷積的尺寸
def cal_convoluation_size(input, kernel, padding=0, stride=1, dilation=1):
    new_kernel = dilation * (kernel - 1) + 1  # 空洞卷積,空洞數(shù)為0時(shí)dilation=1
    # 根據(jù)公式計(jì)算輸出,并返回
    return math.floor((input + 2 * padding - new_kernel) / stride + 1)

# 定義gemm卷積函數(shù):先定義一個(gè)簡(jiǎn)單版本的,不考慮padding、stride、dilation
# images-->(N, C, H, W), kernels-->(out_channels, in_channels, kh, kw), 且 C = in_channels
# 輸出結(jié)果output-->(N, out_channels, output_height, output_width)
def gemm(images, kernels, padding=0, stride=1, dilation=1):
    N, C, H, W = images.shape
    out_channels, in_channels, kh, kw = kernels.shape
    
    # 1.kernels轉(zhuǎn)換為col: (out_channel, in_channel * kh * kw)
    kernel_col = kernels.reshape(out_channels, -1)
    
    # 2.img轉(zhuǎn)換為col
    # 計(jì)算輸出的形狀大小
    out_height = cal_convoluation_size(H, kh, padding, stride, dilation)
    out_width = cal_convoluation_size(W, kw, padding, stride, dilation)
    # img_col的行數(shù)、列數(shù)
    kernel_count = kh * kw
    rows, cols = in_channels * kernel_count, out_height * out_width
    
    # 將圖片的數(shù)量N放在高維,這樣GEMM得到的結(jié)果不用再通過切片去拿
    img_col = np.zeros((N, rows, cols))
    for i in range(N):  # 第幾張圖片
        for idy in range(out_height):
            for idx in range(out_width):
                col_index = idy * out_width + idx
                for ic in range(C):  # C=in_channels
                    for iky in range(kh):
                        for ikx in range(kw):
                            row_index = ic * kernel_count + iky * kw + ikx                            # 賦值
                            img_col[i, row_index, col_index] = images[i, ic, idy + iky, idx + ikx]
    
    # 3.卷積計(jì)算之GEMM方法
    # (out_channels, in_channels * kh * kw) @ (N, in_channels * kh * kw, out_height * out_width)
    # = (N, out_channels, out_height * out_width)
    output = kernel_col @ img_col
    return output.reshape(N, out_channels, out_height, out_width)

二、手動(dòng)實(shí)現(xiàn)Conv2d

反向傳播時(shí),需要將對(duì)columns的梯度轉(zhuǎn)換為對(duì)輸入image的梯度,即還要實(shí)現(xiàn)一個(gè)col2img。

代碼實(shí)現(xiàn):

# 2D卷積
class Conv2d(Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = (kernel_size, kernel_size)
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.bias = bias

        # 權(quán)重初始化:Xavier初始化、Kaiming初始化
        # fan_in = in_channels * kh * kw, fan_out = out_channels * kh * kw
        fan_in = in_channels * kernel_size * kernel_size
        bound = 1 / math.sqrt(fan_in)
        gain = math.sqrt(2)  # ReLU
        self.weight = Parameter(
            np.random.normal(0, gain * bound, size=(out_channels, in_channels, kernel_size, kernel_size)))
        self.bias = Parameter(np.random.uniform(-bound, bound, size=(out_channels,)))

    def forward(self, input):
        # 已添加padding和stride的邏輯,暫時(shí)不考慮dilation
        # 仔細(xì)思考了一下,加dilation不難,邏輯稍微修改一下即可:
        #  1)加dilation,只需將kernel變換一下即可,中間補(bǔ)0即可?!形磧?yōu)化
        self.input = input  # save for backward
        N, _, H, W = input.shape
        kh, kw = self.kernel_size

        # 計(jì)算輸出的形狀大小
        self.out_height = self.cal_convoluation_size(H, kh, self.padding, self.stride, self.dilation)  # save for backward
        self.out_width = self.cal_convoluation_size(W, kw, self.padding, self.stride, self.dilation)  # save for backward

        # kernel轉(zhuǎn)換為col
        self.kernel_col = self.weight.data.reshape(self.out_channels, -1)  # save for backward

        # img轉(zhuǎn)換為col
        self.columns = self.img2col(input, (self.out_channels, self.in_channels, kh, kw),
            (self.out_height, self.out_width), self.padding, self.stride, self.dilation)  # save for backward

        # 卷積計(jì)算之GEMM方法
        # (out_channels, in_channels * kh * kw) @ (N, in_channels * kh * kw, out_height * out_width)
        # = (N, out_channels, out_height * out_width)
        output = self.kernel_col @ self.columns + self.bias.data[..., None]

        # (N, out_channels, out_height * out_width) --> (N, out_channels, out_height, out_width)
        return output.reshape(N, self.out_channels, self.out_height, self.out_width)

    def backward(self, delta):
        '''
        反向計(jì)算weight和bias的梯度,同時(shí)計(jì)算并返回"誤差對(duì)輸入的"誤差項(xiàng)
        delta:反向傳遞過來的"誤差對(duì)輸出的"誤差項(xiàng)
        '''
        # (N, out_channels, out_height, out_width) --> (N, out_channels, out_height * out_width)
        delta = delta.reshape(len(delta), self.out_channels, -1)

        # 計(jì)算對(duì)weight的梯度
        # (N, out_channels, out_height * out_width) @ (N, out_height * out_width, in_channels * kh * kw)
        # = (N, out_channels, in_channels * kh * kw) --> (out_channels, in_channels * kh * kw)
        kernel_col_grad = np.sum(delta @ np.transpose(self.columns, axes=(0, 2, 1)), axis=0)  # 所有樣本對(duì)weight的梯度相加
        # (out_channels, in_channels * kh * kw) --> (out_channels, in_channels, kh, kw)
        self.weight.grad += kernel_col_grad.reshape(self.out_channels, self.in_channels, *self.kernel_size)

        # 計(jì)算對(duì)bias的梯度
        # (N, out_channels, out_height * out_width) --> (out_channels,)
        self.bias.grad += np.sum(delta, axis=(0, 2))  # 所有樣本對(duì)bias的梯度相加

        # 計(jì)算并返回"誤差對(duì)輸入的"誤差項(xiàng)
        # (in_channels * kh * kw, out_channels) @ (N, out_channels, out_height * out_width)
        # = (N, in_channels * kh * kw, out_height * out_width)
        columns_delta = self.kernel_col.T @ delta
        return self.delta_col2img(columns_delta, self.input.shape,
                                  (self.out_channels, self.in_channels, *self.kernel_size),
                                  (self.out_height, self.out_width), self.padding, self.stride, self.dilation)

    # 根據(jù)公式計(jì)算卷積的尺寸
    def cal_convoluation_size(self, input, kernel, padding=0, stride=1, dilation=1):
        new_kernel = dilation * (kernel - 1) + 1  # 空洞卷積,空洞數(shù)為0時(shí)dilation=1
        # 根據(jù)公式計(jì)算輸出,并返回
        return math.floor((input + 2 * padding - new_kernel) / stride + 1)

    # 將img2col從gemm中抽離出來,方便forward和backward
    def img2col(self, images, kernel_shape, out_shape, padding=0, stride=1, dilation=1):
        # 考慮padding
        N, C, H, W = images.shape
        new_images = np.zeros((N, C, H + 2 * padding, W + 2 * padding))  # 周圍padding用0填充
        new_images[:, :, padding:H + padding, padding:W + padding] = images
        
        out_channels, in_channels, kh, kw = kernel_shape
        out_height, out_width = out_shape

        # img_col的行數(shù)、列數(shù)
        kernel_count = kh * kw
        rows, cols = in_channels * kernel_count, out_height * out_width

        # 將圖片的數(shù)量N放在高維,這樣GEMM得到的結(jié)果不用再通過切片去拿  
        columns = np.zeros((N, cols, rows))
        for idy in range(out_height):
            for idx in range(out_width):
                col_index = idy * out_width + idx
                start_y = self.stride * idy
                start_x = self.stride * idx
                columns[:, col_index] = new_images[:, :, start_y:start_y + kh, start_x:start_x + kw].reshape(N, -1)
                    
        return columns.transpose(0, 2, 1)

    def delta_col2img(self, columns_delta, input_shape, kernel_shape, out_shape, padding=0, stride=1, dilation=1):
        '''
        columns_delta: (N, in_channels * kh * kw, out_height * out_width)
        input_shape: (N, C, H, W)
        kernel_shape: (out_channels, in_channels, kh, kw)
        out_shape: (out_height, out_width)
        '''
        N, C, H, W = input_shape
        out_channels, in_channels, kh, kw = kernel_shape
        out_height, out_width = out_shape

        # 考慮padding
        images_delta = np.zeros((N, C, H + 2 * padding, W + 2 * padding))
        for i in range(N):  # 第幾張圖片
            for idy in range(out_height):
                for idx in range(out_width):
                    col_index = idy * out_width + idx
                    column_delta = columns_delta[i, :, col_index]  # (in_channels * kh * kw,)
                    # (in_channels * kh * kw,) --> (in_channels, kh, kw)
                    column_delta = column_delta.reshape(in_channels, kh, kw)
                    
                    # 將每一列的delta疊加到原圖對(duì)應(yīng)位置中
                    for ic, kernel_delta in enumerate(column_delta):
                        for iky, kh_delta in enumerate(kernel_delta):
                            for ikx, kw_delta in enumerate(kh_delta):
                                # 考慮stride
                                images_delta[i, ic, stride * idy + iky, stride * idx + ikx] += column_delta[ic, iky, ikx]

        # 考慮padding,去除外圍的padding
        return images_delta[:, :, padding:H + padding, padding:W + padding]
最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請(qǐng)結(jié)合常識(shí)與多方信息審慎甄別。
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡(jiǎn)書系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容