Convolution Study Notes: MNIST Handwritten Digit Recognition with Conv2D

Convolution visualizations: https://animatedai.github.io/

# Section 1: Import the required libraries
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

print("PyTorch version:", torch.__version__)
print("Device:", "CUDA" if torch.cuda.is_available() else "CPU")
PyTorch version: 2.7.1
Device: CPU
# Section 2: Load the MNIST dataset
def load_data():
    transform = transforms.ToTensor()
    
    train_dataset = torchvision.datasets.MNIST(root='./data', 
                                              train=True, 
                                              download=True, 
                                              transform=transform)
    
    test_dataset = torchvision.datasets.MNIST(root='./data', 
                                             train=False, 
                                             transform=transform)
    
    train_loader = torch.utils.data.DataLoader(train_dataset, 
                                              batch_size=64, 
                                              shuffle=True)
    
    test_loader = torch.utils.data.DataLoader(test_dataset, 
                                             batch_size=64, 
                                             shuffle=False)
    
    return train_loader, test_loader

# Load the data
train_loader, test_loader = load_data()

print(f"Training samples: {len(train_loader.dataset)}")
print(f"Test samples: {len(test_loader.dataset)}")
Training samples: 60000
Test samples: 10000
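With 60,000 training images and `batch_size=64`, the number of batches per epoch follows from simple arithmetic. The sketch below assumes the `DataLoader` default of keeping the final partial batch (`drop_last=False`); the variable names are illustrative only.

```python
import math

num_train = 60_000   # training samples reported above
batch_size = 64      # value passed to DataLoader in load_data()

# With drop_last=False the last, smaller batch is kept,
# so the batch count is a ceiling division.
batches_per_epoch = math.ceil(num_train / batch_size)
last_batch_size = num_train - (batches_per_epoch - 1) * batch_size

print(batches_per_epoch)  # 938
print(last_batch_size)    # 32
```

This is why a training loop that logs every 200 batches prints at batch indices 0, 200, 400, 600 and 800 in each epoch.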
# Section 3: Inspect sample data
def show_sample_data(dataloader, num_samples=8):
    data, labels = next(iter(dataloader))
    
    fig, axes = plt.subplots(2, 4, figsize=(12, 6))
    for i in range(num_samples):
        ax = axes[i//4, i%4]
        ax.imshow(data[i][0], cmap='gray')
        ax.set_title(f'Label: {labels[i].item()}')
        ax.axis('off')
    
    plt.suptitle('Sample MNIST Images')
    plt.tight_layout()
    plt.show()
    
    print(f"Image shape: {data[0].shape}")  # [1, 28, 28]
    print(f"Batch shape: {data.shape}")     # [64, 1, 28, 28]

show_sample_data(train_loader)
output_2_0.png
Image shape: torch.Size([1, 28, 28])
Batch shape: torch.Size([64, 1, 28, 28])
# Section 4: Define the CNN model
class SimpleCNN(nn.Module):
    def __init__(self):
        super().__init__()
        
        # First conv layer: 1 → 16 channels, 5x5 kernel
        self.conv1 = nn.Conv2d(1, 16, kernel_size=5, padding=2)
        self.pool1 = nn.MaxPool2d(2, 2)  # 28x28 → 14x14
        
        # Second conv layer: 16 → 32 channels, 5x5 kernel
        self.conv2 = nn.Conv2d(16, 32, kernel_size=5, padding=2)
        self.pool2 = nn.MaxPool2d(2, 2)  # 14x14 → 7x7
        
        # Fully connected layers
        self.fc1 = nn.Linear(32 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)  # 10 digit classes
        
        # Activation function and dropout
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.2)
    
    def forward(self, x):
        # Conv block 1
        x = self.conv1(x)  # [64, 1, 28, 28] → [64, 16, 28, 28]
        x = self.relu(x)
        x = self.pool1(x)  # [64, 16, 28, 28] → [64, 16, 14, 14]
        
        # Conv block 2
        x = self.conv2(x)  # [64, 16, 14, 14] → [64, 32, 14, 14]
        x = self.relu(x)
        x = self.pool2(x)  # [64, 32, 14, 14] → [64, 32, 7, 7]
        
        # Flatten
        x = x.view(x.size(0), -1)  # [64, 32, 7, 7] → [64, 1568]
        
        # Fully connected layers
        x = self.fc1(x)    # [64, 1568] → [64, 128]
        x = self.relu(x)
        x = self.dropout(x)
        x = self.fc2(x)    # [64, 128] → [64, 10]
        
        return x

# Create the model and inspect its parameters
model = SimpleCNN()
print(model)
print(f"\nTotal parameters: {sum(p.numel() for p in model.parameters()):,}")
SimpleCNN(
  (conv1): Conv2d(1, 16, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(16, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fc1): Linear(in_features=1568, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=10, bias=True)
  (relu): ReLU()
  (dropout): Dropout(p=0.2, inplace=False)
)

Total parameters: 215,370
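The total of 215,370 can be reproduced by hand from the layer shapes printed above. The following is a minimal sketch in plain Python; the helper names `conv2d_params` and `linear_params` are made up for illustration and are not part of PyTorch.

```python
def conv2d_params(in_ch, out_ch, kernel):
    """Weights (out_ch * in_ch * k * k) plus one bias per output channel."""
    return out_ch * in_ch * kernel * kernel + out_ch

def linear_params(in_features, out_features):
    """Weight matrix plus one bias per output feature."""
    return in_features * out_features + out_features

layer_params = {
    "conv1": conv2d_params(1, 16, 5),
    "conv2": conv2d_params(16, 32, 5),
    "fc1": linear_params(32 * 7 * 7, 128),
    "fc2": linear_params(128, 10),
}
total = sum(layer_params.values())

for name, count in layer_params.items():
    print(f"{name}: {count:,}")
print(f"total: {total:,}")  # total: 215,370
```

The breakdown shows that the two convolution layers together contribute only 13,248 parameters; `fc1` alone accounts for 200,832.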
# Section 5: Analyze the Conv2D weight shapes
print("=== Conv2D Weight Shape Analysis ===")
print(f"Conv1 weight shape: {model.conv1.weight.shape}")
print("Meaning: [16 output channels, 1 input channel, 5 height, 5 width]")

print(f"\nConv2 weight shape: {model.conv2.weight.shape}")
print("Meaning: [32 output channels, 16 input channels, 5 height, 5 width]")

print("\n=== Data Flow Through Model ===")
sample_input = torch.randn(1, 1, 28, 28)
print(f"Input shape: {sample_input.shape}")

with torch.no_grad():
    x = model.conv1(sample_input)
    print(f"After Conv1: {x.shape}")
    
    x = model.relu(x)
    x = model.pool1(x)
    print(f"After Pool1: {x.shape}")
    
    x = model.conv2(x)
    print(f"After Conv2: {x.shape}")
    
    x = model.relu(x)
    x = model.pool2(x)
    print(f"After Pool2: {x.shape}")
    
    x = x.view(x.size(0), -1)
    print(f"After Flatten: {x.shape}")
    
    x = model.fc1(x)
    print(f"After FC1: {x.shape}")
    
    x = model.fc2(x)
    print(f"Final Output: {x.shape}")
=== Conv2D Weight Shape Analysis ===
Conv1 weight shape: torch.Size([16, 1, 5, 5])
Meaning: [16 output channels, 1 input channel, 5 height, 5 width]

Conv2 weight shape: torch.Size([32, 16, 5, 5])
Meaning: [32 output channels, 16 input channels, 5 height, 5 width]

=== Data Flow Through Model ===
Input shape: torch.Size([1, 1, 28, 28])
After Conv1: torch.Size([1, 16, 28, 28])
After Pool1: torch.Size([1, 16, 14, 14])
After Conv2: torch.Size([1, 32, 14, 14])
After Pool2: torch.Size([1, 32, 7, 7])
After Flatten: torch.Size([1, 1568])
After FC1: torch.Size([1, 128])
Final Output: torch.Size([1, 10])
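The spatial sizes in this trace follow the usual output-size formula `floor((size + 2*padding - kernel) / stride) + 1`, which holds for both `Conv2d` and `MaxPool2d` when dilation is 1. A small sketch of that arithmetic, with `out_size` as a hypothetical helper name:

```python
def out_size(size, kernel, stride=1, padding=0):
    """Spatial output size for Conv2d / MaxPool2d (dilation=1, floor rounding)."""
    return (size + 2 * padding - kernel) // stride + 1

sizes = []
size = 28  # MNIST input height and width
for name, kernel, stride, padding in [
    ("conv1", 5, 1, 2),
    ("pool1", 2, 2, 0),
    ("conv2", 5, 1, 2),
    ("pool2", 2, 2, 0),
]:
    size = out_size(size, kernel, stride, padding)
    sizes.append(size)
    print(f"{name}: {size}x{size}")

flattened = 32 * size * size  # 32 channels come out of conv2
print(f"flattened features: {flattened}")  # flattened features: 1568
```

With `kernel_size=5` and `padding=2` the convolutions preserve the spatial size, so only the pooling layers shrink the feature maps.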
# Section 6: Define the training function
def train_model(model, train_loader, epochs=3):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    
    model.train()
    train_losses = []
    
    for epoch in range(epochs):
        epoch_loss = 0
        correct = 0
        total = 0
        
        for batch_idx, (data, target) in enumerate(train_loader):
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            
            epoch_loss += loss.item()
            
            # Compute accuracy
            predicted = output.argmax(dim=1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
            
            # Print every 200 batches
            if batch_idx % 200 == 0:
                print(f'Epoch {epoch+1}, Batch {batch_idx}, Loss: {loss.item():.4f}')
        
        avg_loss = epoch_loss / len(train_loader)
        accuracy = 100 * correct / total
        train_losses.append(avg_loss)
        
        print(f'Epoch {epoch+1} Complete: Loss={avg_loss:.4f}, Accuracy={accuracy:.2f}%')
    
    return train_losses
# Section 7: Train the model
print("Starting training...")
train_losses = train_model(model, train_loader, epochs=3)

# Plot the training loss
plt.figure(figsize=(8, 5))
plt.plot(train_losses, marker='o')
plt.title('Training Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.grid(True)
plt.show()
Starting training...
Epoch 1, Batch 0, Loss: 2.3039
Epoch 1, Batch 200, Loss: 0.3510
Epoch 1, Batch 400, Loss: 0.1626
Epoch 1, Batch 600, Loss: 0.0927
Epoch 1, Batch 800, Loss: 0.0329
Epoch 1 Complete: Loss=0.2248, Accuracy=93.09%
Epoch 2, Batch 0, Loss: 0.0343
Epoch 2, Batch 200, Loss: 0.0524
Epoch 2, Batch 400, Loss: 0.0231
Epoch 2, Batch 600, Loss: 0.0102
Epoch 2, Batch 800, Loss: 0.0253
Epoch 2 Complete: Loss=0.0660, Accuracy=97.89%
Epoch 3, Batch 0, Loss: 0.0603
Epoch 3, Batch 200, Loss: 0.0493
Epoch 3, Batch 400, Loss: 0.0388
Epoch 3, Batch 600, Loss: 0.0033
Epoch 3, Batch 800, Loss: 0.1216
Epoch 3 Complete: Loss=0.0472, Accuracy=98.53%
output_6_1.png
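The first logged loss, 2.3039, is close to ln 10 ≈ 2.3026. That is the value softmax cross-entropy gives when a classifier spreads its probability evenly over 10 classes, which is roughly the state of an untrained network. A pure-Python sketch of that arithmetic follows; the `cross_entropy` helper is illustrative and is not PyTorch's implementation.

```python
import math

def cross_entropy(logits, target):
    """Softmax cross-entropy for one sample: -log(softmax(logits)[target])."""
    m = max(logits)  # subtract the max for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]

# Equal logits mean every class gets probability 1/10.
uniform_logits = [0.0] * 10
loss = cross_entropy(uniform_logits, target=3)

print(f"{loss:.4f}")  # 2.3026
```

A starting loss near this value is a useful sanity check that the model and loss are wired up correctly before any learning has happened.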
# Section 8: Test the model
def test_model(model, test_loader):
    model.eval()
    correct = 0
    total = 0
    class_correct = [0] * 10
    class_total = [0] * 10
    
    with torch.no_grad():
        for data, target in test_loader:
            output = model(data)
            _, predicted = torch.max(output, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
            
            # Per-digit accuracy
            for i in range(target.size(0)):
                label = target[i].item()
                class_correct[label] += int(predicted[i].item() == label)
                class_total[label] += 1
    
    print(f'Overall Test Accuracy: {100 * correct / total:.2f}%')
    
    print('\nPer-digit Recognition Accuracy:')
    for i in range(10):
        if class_total[i] > 0:
            acc = 100 * class_correct[i] / class_total[i]
            print(f'Digit {i}: {acc:.1f}% ({class_correct[i]}/{class_total[i]})')

test_model(model, test_loader)
Overall Test Accuracy: 98.94%

Per-digit Recognition Accuracy:
Digit 0: 99.4% (974/980)
Digit 1: 99.7% (1132/1135)
Digit 2: 99.6% (1028/1032)
Digit 3: 99.2% (1002/1010)
Digit 4: 97.7% (959/982)
Digit 5: 98.7% (880/892)
Digit 6: 99.1% (949/958)
Digit 7: 98.4% (1012/1028)
Digit 8: 98.3% (957/974)
Digit 9: 99.2% (1001/1009)
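As a consistency check, the per-digit counts should add up to the overall accuracy. The sketch below copies the (correct, total) pairs from the output above into a plain dictionary; the variable names are illustrative.

```python
# (correct, total) per digit, copied from the test output above
per_digit = {
    0: (974, 980), 1: (1132, 1135), 2: (1028, 1032), 3: (1002, 1010),
    4: (959, 982), 5: (880, 892), 6: (949, 958), 7: (1012, 1028),
    8: (957, 974), 9: (1001, 1009),
}

correct = sum(c for c, _ in per_digit.values())
total = sum(t for _, t in per_digit.values())
print(f"{100 * correct / total:.2f}%")  # 98.94%

# Digit with the lowest recognition rate in this run
hardest = min(per_digit, key=lambda d: per_digit[d][0] / per_digit[d][1])
print(hardest)  # 4
```

The counts sum to 9,894 correct out of 10,000, matching the reported 98.94%, and digit 4 is the weakest class in this particular run.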
# Section 9: Visualize predictions
def visualize_predictions(model, test_loader, num_samples=16):
    model.eval()
    
    data, target = next(iter(test_loader))
    
    with torch.no_grad():
        output = model(data)
        _, predicted = torch.max(output, 1)
    
    fig, axes = plt.subplots(4, 4, figsize=(12, 12))
    for i in range(num_samples):
        ax = axes[i//4, i%4]
        ax.imshow(data[i][0], cmap='gray')
        
        true_label = target[i].item()
        pred_label = predicted[i].item()
        color = 'green' if true_label == pred_label else 'red'
        
        ax.set_title(f'True: {true_label}, Pred: {pred_label}', color=color)
        ax.axis('off')
    
    plt.suptitle('Prediction Results (Green=Correct, Red=Wrong)')
    plt.tight_layout()
    plt.show()

visualize_predictions(model, test_loader)
output_8_0.png
# Section 10: Analyze the learned convolution kernels
def analyze_conv_kernels(model):
    conv1_weight = model.conv1.weight.data  # [16, 1, 5, 5]
    
    print(f"First layer kernels shape: {conv1_weight.shape}")
    print("These are the learned feature detectors!")
    
    fig, axes = plt.subplots(2, 8, figsize=(16, 4))
    for i in range(16):
        ax = axes[i//8, i%8]
        kernel = conv1_weight[i, 0].numpy()
        im = ax.imshow(kernel, cmap='RdBu')
        ax.set_title(f'Kernel {i+1}')
        ax.axis('off')
    
    plt.suptitle('Learned Conv Kernels (Feature Detectors)')
    plt.tight_layout()
    plt.show()

analyze_conv_kernels(model)
First layer kernels shape: torch.Size([16, 1, 5, 5])
These are the learned feature detectors!
output_9_1.png
# Section 11: Visualize feature maps
def visualize_feature_maps(model, test_loader):
    model.eval()
    
    data, _ = next(iter(test_loader))
    sample_image = data[0:1]  # Take the first image, keeping the batch dimension
    
    with torch.no_grad():
        # Feature maps after the first conv layer
        x = model.conv1(sample_image)
        x = model.relu(x)
        feature_maps = x[0]  # [16, 28, 28]
    
    # Show the original image and the 16 feature maps
    fig, axes = plt.subplots(3, 6, figsize=(15, 8))
    
    # Original image
    axes[0, 0].imshow(sample_image[0, 0], cmap='gray')
    axes[0, 0].set_title('Original Image')
    axes[0, 0].axis('off')
    
    # Feature maps
    for i in range(16):
        if i < 5:  # Remaining cells of the first row
            row, col = 0, i+1
        elif i < 11:  # Second row
            row, col = 1, i-5
        else:  # Third row
            row, col = 2, i-11
        
        axes[row, col].imshow(feature_maps[i], cmap='viridis')
        axes[row, col].set_title(f'Feature {i+1}')
        axes[row, col].axis('off')
    
    # The 3x6 grid has 18 cells but only 17 are used; hide the empty one
    axes[2, 5].axis('off')
    
    plt.suptitle('Feature Maps after First Conv Layer')
    plt.tight_layout()
    plt.show()

visualize_feature_maps(model, test_loader)
output_10_0.png
# Section 12: Compare feature activations across digits
def compare_digit_features(model, test_loader):
    model.eval()
    
    # Find one sample each for the first five distinct digits
    digits_found = {}
    with torch.no_grad():
        for data, target in test_loader:
            for i, label in enumerate(target):
                digit = label.item()
                if digit not in digits_found and len(digits_found) < 5:
                    digits_found[digit] = data[i:i+1]
                if len(digits_found) == 5:
                    break
            if len(digits_found) == 5:
                break
    
    fig, axes = plt.subplots(5, 6, figsize=(15, 12))
    
    for row, (digit, image) in enumerate(digits_found.items()):
        # Original image
        axes[row, 0].imshow(image[0, 0], cmap='gray')
        axes[row, 0].set_title(f'Digit {digit}')
        axes[row, 0].axis('off')
        
        # Feature maps
        with torch.no_grad():
            features = model.relu(model.conv1(image))[0]
        
        # Show the first 5 feature maps
        for col in range(1, 6):
            axes[row, col].imshow(features[col-1], cmap='viridis')
            axes[row, col].set_title(f'Feature {col}')
            axes[row, col].axis('off')
    
    plt.suptitle('How Different Digits Activate Different Features')
    plt.tight_layout()
    plt.show()

compare_digit_features(model, test_loader)
output_11_0.png
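# Section 13: Experiment summary
print("=== MNIST CNN Experiment Summary ===")
print("\n1. Understanding Conv2D weight shapes:")
print("   - Conv1: [16, 1, 5, 5] = [out channels, in channels, height, width]")
print("   - Conv2: [32, 16, 5, 5] = [out channels, in channels, height, width]")

print("\n2. Takeaways:")
print("   - A CNN learns useful feature detectors on its own")
print("   - Different kernels respond to different patterns")
print("   - Feature maps show what the network 'sees'")
print("   - Deeper layers learn more complex features")

print("\n3. Performance:")
print("   - 98.94% test accuracy after only 3 epochs")
print(f"   - Total parameters: {sum(p.numel() for p in model.parameters()):,}")
print("   - The two conv layers hold only 13,248 of them; fc1 holds most of the rest")

print("\n4. Key insights:")
print("   - Conv2D is well suited to image recognition")
print("   - Weight sharing makes the conv layers parameter-efficient")
print("   - Visualization helps show what the model has learned")
print("   - Kernels are feature detectors learned automatically")

print("\n5. Conv2D core ideas:")
print("   - Input: [batch, 1, 28, 28] grayscale images")
print("   - Conv1: 1 → 16 channels, learns 16 basic features")
print("   - Conv2: 16 → 32 channels, combines them into 32 more complex features")
print("   - Finally: 32 feature maps of 7x7 → fully connected layers → 10 classes")

print("\n6. Why CNNs work so well:")
print("   - Local connectivity: each unit looks only at nearby pixels, matching image structure")
print("   - Weight sharing: the same feature can be detected anywhere in the image")
print("   - Shift tolerance: conv is shift-equivariant and pooling absorbs small shifts")
print("   - Hierarchical features: abstraction builds layer by layer, from edges to shapes")
=== MNIST CNN Experiment Summary ===

1. Understanding Conv2D weight shapes:
   - Conv1: [16, 1, 5, 5] = [out channels, in channels, height, width]
   - Conv2: [32, 16, 5, 5] = [out channels, in channels, height, width]

2. Takeaways:
   - A CNN learns useful feature detectors on its own
   - Different kernels respond to different patterns
   - Feature maps show what the network 'sees'
   - Deeper layers learn more complex features

3. Performance:
   - 98.94% test accuracy after only 3 epochs
   - Total parameters: 215,370
   - The two conv layers hold only 13,248 of them; fc1 holds most of the rest

4. Key insights:
   - Conv2D is well suited to image recognition
   - Weight sharing makes the conv layers parameter-efficient
   - Visualization helps show what the model has learned
   - Kernels are feature detectors learned automatically

5. Conv2D core ideas:
   - Input: [batch, 1, 28, 28] grayscale images
   - Conv1: 1 → 16 channels, learns 16 basic features
   - Conv2: 16 → 32 channels, combines them into 32 more complex features
   - Finally: 32 feature maps of 7x7 → fully connected layers → 10 classes

6. Why CNNs work so well:
   - Local connectivity: each unit looks only at nearby pixels, matching image structure
   - Weight sharing: the same feature can be detected anywhere in the image
   - Shift tolerance: conv is shift-equivariant and pooling absorbs small shifts
   - Hierarchical features: abstraction builds layer by layer, from edges to shapes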
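The summary credits weight sharing with letting one kernel detect the same feature anywhere in the image. A tiny pure-Python sketch can make that concrete: it slides a hypothetical 2x2 vertical-edge kernel over a made-up 6x6 image, then over the same image shifted two pixels to the right. The image, kernel, and helper names are all assumptions for illustration, and the helper omits bias and padding.

```python
def correlate2d(image, kernel):
    """Valid (no padding) 2D cross-correlation, the sliding-window
    operation behind nn.Conv2d, with the bias term omitted."""
    kh, kw = len(kernel), len(kernel[0])
    out_h = len(image) - kh + 1
    out_w = len(image[0]) - kw + 1
    return [
        [
            sum(image[r + i][c + j] * kernel[i][j]
                for i in range(kh) for j in range(kw))
            for c in range(out_w)
        ]
        for r in range(out_h)
    ]

def shift_right(grid, n):
    """Shift every row n cells to the right, filling with zeros."""
    return [[0] * n + row[:len(row) - n] for row in grid]

# A 6x6 image with a short vertical stroke in column 1.
image = [[0] * 6 for _ in range(6)]
for r in range(1, 5):
    image[r][1] = 1

kernel = [[-1, 1], [-1, 1]]  # hypothetical vertical-edge detector

original = correlate2d(image, kernel)
shifted = correlate2d(shift_right(image, 2), kernel)

print(original[2])  # [2, -2, 0, 0, 0]
print(shifted[2])   # [0, 0, 2, -2, 0]
print(shifted == shift_right(original, 2))  # True
```

The same four weights respond to the stroke in both positions, and the response map simply moves along with the input. This property is usually called translation equivariance; pooling then makes the downstream features less sensitive to small shifts.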
