卷積可視化:https://animatedai.github.io/

image.png
# 第1段:導入必要的庫
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
# Report the installed PyTorch version and whether CUDA is available.
# NOTE(review): this only prints the result; nothing visible here moves the
# model or the data to a CUDA device.
print("PyTorch version:", torch.__version__)
print("Device:", "CUDA" if torch.cuda.is_available() else "CPU")
PyTorch version: 2.7.1
Device: CPU
# Section 2: load the MNIST dataset
def load_data(batch_size=64, root='./data'):
    """Build the MNIST training and test data loaders.

    Both splits are converted with ``transforms.ToTensor()`` and are
    downloaded into ``root`` when the files are not there yet.

    Args:
        batch_size: Number of samples per batch, used for both loaders.
        root: Directory in which torchvision stores the MNIST files.

    Returns:
        A ``(train_loader, test_loader)`` tuple. Only the training loader
        shuffles its samples.
    """
    transform = transforms.ToTensor()

    train_dataset = torchvision.datasets.MNIST(
        root=root,
        train=True,
        download=True,
        transform=transform,
    )
    # download=True here as well, so building the test split does not depend
    # on the training split having been created first.
    test_dataset = torchvision.datasets.MNIST(
        root=root,
        train=False,
        download=True,
        transform=transform,
    )

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
    )
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
    )
    return train_loader, test_loader
# Load the data and report how many samples each split contains.
train_loader, test_loader = load_data()
print(f"Training samples: {len(train_loader.dataset)}")
print(f"Test samples: {len(test_loader.dataset)}")
Training samples: 60000
Test samples: 10000
# Section 3: inspect sample data
def show_sample_data(dataloader, num_samples=8):
    """Plot images from the first batch of ``dataloader`` and print shapes.

    Args:
        dataloader: Iterable yielding ``(images, labels)`` batches. Only the
            first batch is consumed.
        num_samples: Maximum number of images to draw. The value is clamped
            to the batch size, and the grid grows with it (4 images per row).
    """
    data, labels = next(iter(dataloader))

    shown = max(0, min(num_samples, len(data)))
    cols = 4
    # Ceiling division; keep at least one row so the grid is always valid.
    rows = max(1, -(-shown // cols))
    # squeeze=False keeps `axes` two-dimensional even when there is one row.
    _, axes = plt.subplots(rows, cols, figsize=(12, 3 * rows), squeeze=False)

    for i, ax in enumerate(axes.flat):
        if i < shown:
            ax.imshow(data[i][0], cmap='gray')  # channel 0 of image i
            ax.set_title(f'Label: {labels[i].item()}')
        # Applied to every cell so unused cells in the last row stay blank.
        ax.axis('off')

    plt.suptitle('Sample MNIST Images')
    plt.tight_layout()
    plt.show()

    print(f"Image shape: {data[0].shape}")  # [1, 28, 28] for MNIST
    print(f"Batch shape: {data.shape}")  # [64, 1, 28, 28] with batch size 64
# Preview images from one batch of the training loader.
show_sample_data(train_loader)

output_2_0.png
Image shape: torch.Size([1, 28, 28])
Batch shape: torch.Size([64, 1, 28, 28])
# Section 4: define the CNN model
class SimpleCNN(nn.Module):
    """Small CNN with two convolution blocks and two linear layers.

    The first linear layer expects ``32 * 7 * 7`` inputs, which is what two
    2x2 poolings produce from a 28x28 single-channel image.

    Args:
        num_classes: Number of output logits (10 digit classes by default).
        dropout_p: Dropout probability applied after the first linear layer.
    """

    def __init__(self, num_classes=10, dropout_p=0.2):
        super().__init__()
        # Layers are created in a fixed order so that print(model) and the
        # random initialisation sequence stay the same.
        # First convolution: 1 -> 16 channels, 5x5 kernel, size-preserving.
        self.conv1 = nn.Conv2d(1, 16, kernel_size=5, padding=2)
        self.pool1 = nn.MaxPool2d(2, 2)  # 28x28 -> 14x14
        # Second convolution: 16 -> 32 channels, 5x5 kernel, size-preserving.
        self.conv2 = nn.Conv2d(16, 32, kernel_size=5, padding=2)
        self.pool2 = nn.MaxPool2d(2, 2)  # 14x14 -> 7x7
        # Fully connected classifier head.
        self.fc1 = nn.Linear(32 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, num_classes)
        # Shared activation and regularisation modules.
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout_p)

    def forward(self, x):
        """Return class logits for a batch of images.

        Args:
            x: Tensor of shape ``[N, 1, 28, 28]``.

        Returns:
            Tensor of shape ``[N, num_classes]`` with unnormalised scores.
        """
        # Convolution block 1: [N, 1, 28, 28] -> [N, 16, 14, 14]
        x = self.pool1(self.relu(self.conv1(x)))
        # Convolution block 2: [N, 16, 14, 14] -> [N, 32, 7, 7]
        x = self.pool2(self.relu(self.conv2(x)))
        # Flatten everything but the batch dimension: -> [N, 1568].
        # torch.flatten also works when the tensor is not contiguous,
        # where .view() would raise.
        x = torch.flatten(x, 1)
        # Classifier head: [N, 1568] -> [N, 128] -> [N, num_classes]
        x = self.dropout(self.relu(self.fc1(x)))
        return self.fc2(x)
# Create the model, print its layer structure and count its parameters.
model = SimpleCNN()
print(model)
print(f"\nTotal parameters: {sum(p.numel() for p in model.parameters()):,}")
SimpleCNN(
(conv1): Conv2d(1, 16, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
(pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(conv2): Conv2d(16, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
(pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(fc1): Linear(in_features=1568, out_features=128, bias=True)
(fc2): Linear(in_features=128, out_features=10, bias=True)
(relu): ReLU()
(dropout): Dropout(p=0.2, inplace=False)
)
Total parameters: 215,370
# Section 5: analyse the Conv2D weight shapes
conv1_shape = model.conv1.weight.shape
conv2_shape = model.conv2.weight.shape
print("=== Conv2D Weight Shape Analysis ===")
print(f"Conv1 weight shape: {conv1_shape}")
print("Meaning: [16 output channels, 1 input channel, 5 height, 5 width]")
print(f"\nConv2 weight shape: {conv2_shape}")
print("Meaning: [32 output channels, 16 input channels, 5 height, 5 width]")

print("\n=== Data Flow Through Model ===")
sample_input = torch.randn(1, 1, 28, 28)
print(f"Input shape: {sample_input.shape}")

# Each stage is a label plus the operations applied before the shape is
# printed. Like the step-by-step trace it replaces, this skips dropout and
# the ReLU after fc1; neither changes a tensor's shape.
trace_stages = (
    ("After Conv1", (model.conv1,)),
    ("After Pool1", (model.relu, model.pool1)),
    ("After Conv2", (model.conv2,)),
    ("After Pool2", (model.relu, model.pool2)),
    ("After Flatten", (lambda t: t.view(t.size(0), -1),)),
    ("After FC1", (model.fc1,)),
    ("Final Output", (model.fc2,)),
)
with torch.no_grad():
    x = sample_input
    for stage_label, stage_ops in trace_stages:
        for stage_op in stage_ops:
            x = stage_op(x)
        print(f"{stage_label}: {x.shape}")
=== Conv2D Weight Shape Analysis ===
Conv1 weight shape: torch.Size([16, 1, 5, 5])
Meaning: [16 output channels, 1 input channel, 5 height, 5 width]
Conv2 weight shape: torch.Size([32, 16, 5, 5])
Meaning: [32 output channels, 16 input channels, 5 height, 5 width]
=== Data Flow Through Model ===
Input shape: torch.Size([1, 1, 28, 28])
After Conv1: torch.Size([1, 16, 28, 28])
After Pool1: torch.Size([1, 16, 14, 14])
After Conv2: torch.Size([1, 32, 14, 14])
After Pool2: torch.Size([1, 32, 7, 7])
After Flatten: torch.Size([1, 1568])
After FC1: torch.Size([1, 128])
Final Output: torch.Size([1, 10])
# Section 6: define the training function
def train_model(model, train_loader, epochs=3, lr=0.001, log_interval=200):
    """Train ``model`` with Adam and cross-entropy loss.

    Args:
        model: Module to optimise in place. It is switched to training mode.
        train_loader: Iterable of ``(data, target)`` batches. It does not
            need to support ``len()``.
        epochs: Number of passes over ``train_loader``.
        lr: Learning rate handed to Adam.
        log_interval: Print the batch loss every this many batches; a falsy
            value disables the per-batch log.

    Returns:
        A list with one float per epoch: the mean of that epoch's batch
        losses, or ``nan`` for an epoch in which no batch was produced.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    # Adam rejects an empty parameter list, so a first parameter exists here.
    # Batches are moved to its device so a model on CUDA also works.
    device = next(model.parameters()).device

    model.train()
    train_losses = []
    for epoch in range(epochs):
        epoch_loss = 0.0
        correct = 0
        total = 0
        num_batches = 0
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)

            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            epoch_loss += batch_loss
            num_batches += 1

            # Running accuracy; detach() keeps this out of the autograd graph.
            predicted = output.detach().argmax(dim=1)
            total += target.size(0)
            correct += (predicted == target).sum().item()

            if log_interval and batch_idx % log_interval == 0:
                print(f'Epoch {epoch+1}, Batch {batch_idx}, Loss: {batch_loss:.4f}')

        # Guard the divisions so an empty loader does not raise.
        avg_loss = epoch_loss / num_batches if num_batches else float('nan')
        accuracy = 100 * correct / total if total else 0.0
        train_losses.append(avg_loss)
        print(f'Epoch {epoch+1} Complete: Loss={avg_loss:.4f}, Accuracy={accuracy:.2f}%')
    return train_losses
# Section 7: train the model
print("Starting training...")
train_losses = train_model(model, train_loader, epochs=3)

# Plot the per-epoch training loss. Epochs are numbered from 1 so the x-axis
# matches the "Epoch N" lines printed during training.
epoch_numbers = list(range(1, len(train_losses) + 1))
plt.figure(figsize=(8, 5))
plt.plot(epoch_numbers, train_losses, marker='o')
plt.xticks(epoch_numbers)
plt.title('Training Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.grid(True)
plt.show()
Starting training...
Epoch 1, Batch 0, Loss: 2.3039
Epoch 1, Batch 200, Loss: 0.3510
Epoch 1, Batch 400, Loss: 0.1626
Epoch 1, Batch 600, Loss: 0.0927
Epoch 1, Batch 800, Loss: 0.0329
Epoch 1 Complete: Loss=0.2248, Accuracy=93.09%
Epoch 2, Batch 0, Loss: 0.0343
Epoch 2, Batch 200, Loss: 0.0524
Epoch 2, Batch 400, Loss: 0.0231
Epoch 2, Batch 600, Loss: 0.0102
Epoch 2, Batch 800, Loss: 0.0253
Epoch 2 Complete: Loss=0.0660, Accuracy=97.89%
Epoch 3, Batch 0, Loss: 0.0603
Epoch 3, Batch 200, Loss: 0.0493
Epoch 3, Batch 400, Loss: 0.0388
Epoch 3, Batch 600, Loss: 0.0033
Epoch 3, Batch 800, Loss: 0.1216
Epoch 3 Complete: Loss=0.0472, Accuracy=98.53%

output_6_1.png
# Section 8: evaluate the model
def test_model(model, test_loader, num_classes=10):
    """Print overall and per-class accuracy of ``model`` on ``test_loader``.

    Args:
        model: Module returning one row of class scores per sample. It is
            switched to evaluation mode.
        test_loader: Iterable of ``(data, target)`` batches. Targets are
            expected to be integer labels in ``range(num_classes)``.
        num_classes: Number of classes to keep per-class counts for.

    Returns:
        None. The results are printed.
    """
    model.eval()
    # Run on the model's device; parameter-free models get the data as is.
    first_param = next(model.parameters(), None)
    device = None if first_param is None else first_param.device

    class_correct = torch.zeros(num_classes, dtype=torch.long)
    class_total = torch.zeros(num_classes, dtype=torch.long)
    with torch.no_grad():
        for data, target in test_loader:
            inputs = data if device is None else data.to(device)
            predicted = model(inputs).argmax(dim=1).cpu()
            target = target.cpu()
            hits = predicted == target
            # bincount tallies the whole batch at once instead of looping
            # over samples in Python.
            class_total += torch.bincount(target, minlength=num_classes)
            class_correct += torch.bincount(target[hits], minlength=num_classes)

    total = int(class_total.sum())
    correct = int(class_correct.sum())
    if total == 0:
        # Nothing was evaluated; avoid dividing by zero below.
        print('Overall Test Accuracy: n/a (no test samples)')
        return

    print(f'Overall Test Accuracy: {100 * correct / total:.2f}%')
    print('\nPer-digit Recognition Accuracy:')
    per_class = zip(class_correct.tolist(), class_total.tolist())
    for digit, (hit, seen) in enumerate(per_class):
        if seen > 0:
            acc = 100 * hit / seen
            print(f'Digit {digit}: {acc:.1f}% ({hit}/{seen})')


# The name matches pytest's default ``test_*`` pattern; mark it so a test
# runner importing this module does not try to collect it as a test case.
test_model.__test__ = False
# Evaluate the trained model on the test loader and print the accuracies.
test_model(model, test_loader)
Overall Test Accuracy: 98.94%
Per-digit Recognition Accuracy:
Digit 0: 99.4% (974/980)
Digit 1: 99.7% (1132/1135)
Digit 2: 99.6% (1028/1032)
Digit 3: 99.2% (1002/1010)
Digit 4: 97.7% (959/982)
Digit 5: 98.7% (880/892)
Digit 6: 99.1% (949/958)
Digit 7: 98.4% (1012/1028)
Digit 8: 98.3% (957/974)
Digit 9: 99.2% (1001/1009)
# Section 9: visualize prediction results
def visualize_predictions(model, test_loader, num_samples=16):
    """Plot images from the first test batch with true and predicted labels.

    Titles are green when the prediction matches the label and red otherwise.

    Args:
        model: Module returning one row of class scores per sample. It is
            switched to evaluation mode.
        test_loader: Iterable of ``(data, target)`` batches. Only the first
            batch is consumed.
        num_samples: Maximum number of images to draw. The value is clamped
            to the batch size, and the grid grows with it (4 images per row).
    """
    model.eval()
    data, target = next(iter(test_loader))

    # Run on the model's device, but keep `data` where it is for plotting.
    first_param = next(model.parameters(), None)
    with torch.no_grad():
        inputs = data if first_param is None else data.to(first_param.device)
        predicted = model(inputs).argmax(dim=1).cpu()

    shown = max(0, min(num_samples, len(data)))
    cols = 4
    # Ceiling division; keep at least one row so the grid is always valid.
    rows = max(1, -(-shown // cols))
    # squeeze=False keeps `axes` two-dimensional even when there is one row.
    _, axes = plt.subplots(rows, cols, figsize=(12, 3 * rows), squeeze=False)

    for i, ax in enumerate(axes.flat):
        if i < shown:
            ax.imshow(data[i][0], cmap='gray')
            true_label = target[i].item()
            pred_label = predicted[i].item()
            color = 'green' if true_label == pred_label else 'red'
            ax.set_title(f'True: {true_label}, Pred: {pred_label}', color=color)
        # Applied to every cell so unused cells in the last row stay blank.
        ax.axis('off')

    plt.suptitle('Prediction Results (Green=Correct, Red=Wrong)')
    plt.tight_layout()
    plt.show()
# Show predictions for images from the first test batch.
visualize_predictions(model, test_loader)

output_8_0.png
# Section 10: analyse the learned convolution kernels
def analyze_conv_kernels(model):
    """Print the shape of ``model.conv1``'s weights and plot each kernel.

    Only input channel 0 of every kernel is drawn.

    Args:
        model: Module with a ``conv1`` convolution layer.
    """
    # detach().cpu() so the NumPy conversion below also works when the
    # weights live on a CUDA device. Layout: [out, in, height, width].
    conv1_weight = model.conv1.weight.detach().cpu()
    print(f"First layer kernels shape: {conv1_weight.shape}")
    print("These are the learned feature detectors!")

    num_kernels = conv1_weight.shape[0]
    cols = 8
    # Ceiling division; keep at least one row so the grid is always valid.
    rows = max(1, -(-num_kernels // cols))
    # squeeze=False keeps `axes` two-dimensional even when there is one row.
    _, axes = plt.subplots(rows, cols, figsize=(16, 2 * rows), squeeze=False)

    for i, ax in enumerate(axes.flat):
        if i < num_kernels:
            kernel = conv1_weight[i, 0].numpy()
            ax.imshow(kernel, cmap='RdBu')
            ax.set_title(f'Kernel {i+1}')
        # Applied to every cell so unused cells in the last row stay blank.
        ax.axis('off')

    plt.suptitle('Learned Conv Kernels (Feature Detectors)')
    plt.tight_layout()
    plt.show()
# Plot the kernels of the model's first convolution layer.
analyze_conv_kernels(model)
First layer kernels shape: torch.Size([16, 1, 5, 5])
These are the learned feature detectors!

output_9_1.png
# Section 11: visualize feature maps
def visualize_feature_maps(model, test_loader):
    """Plot the first test image next to its first-layer feature maps.

    The feature maps are the output of ``model.conv1`` followed by
    ``model.relu``. Cells are filled row by row, 6 per row, with the
    original image in the first cell.

    Args:
        model: Module with ``conv1`` and ``relu`` attributes. It is switched
            to evaluation mode.
        test_loader: Iterable of ``(data, target)`` batches. Only the first
            image of the first batch is used.
    """
    model.eval()
    data, _ = next(iter(test_loader))
    sample_image = data[0:1]  # slice keeps the batch dimension

    with torch.no_grad():
        # Run on the layer's device, then bring the result back for plotting.
        device = model.conv1.weight.device
        activated = model.relu(model.conv1(sample_image.to(device)))
        feature_maps = activated[0].cpu()  # [16, 28, 28] for SimpleCNN

    num_maps = feature_maps.shape[0]
    cols = 6
    # One extra cell for the original image; ceiling division for the rows.
    rows = -(-(num_maps + 1) // cols)
    # squeeze=False keeps `axes` two-dimensional even when there is one row.
    # The height scales so that 3 rows give the 15x8 figure.
    _, axes = plt.subplots(rows, cols, figsize=(15, 8 * rows / 3), squeeze=False)
    cells = list(axes.flat)

    # Original image in the first cell.
    cells[0].imshow(sample_image[0, 0], cmap='gray')
    cells[0].set_title('Original Image')

    # Feature map i goes into flat cell i + 1.
    for i in range(num_maps):
        cells[i + 1].imshow(feature_maps[i], cmap='viridis')
        cells[i + 1].set_title(f'Feature {i+1}')

    # Hide the axes everywhere, including any unused trailing cell.
    for ax in cells:
        ax.axis('off')

    plt.suptitle('Feature Maps after First Conv Layer')
    plt.tight_layout()
    plt.show()
# Show the first-layer feature maps for the first test image.
visualize_feature_maps(model, test_loader)

output_10_0.png
# Section 12: compare feature activations across digits
def compare_digit_features(model, test_loader):
    """Plot first-layer activations for five different digits.

    One sample is kept for each of the first five distinct labels found
    while scanning ``test_loader``. Each sample fills one row of a 5x6 grid:
    the image itself, then the first five feature maps of
    ``model.relu(model.conv1(image))``. Rows stay empty when fewer than five
    distinct labels are found.

    Args:
        model: Module with ``conv1`` and ``relu`` attributes. It is switched
            to evaluation mode.
        test_loader: Iterable of ``(data, target)`` batches.
    """
    model.eval()

    # Collect one sample per digit, in the order the digits are first seen.
    samples = {}
    with torch.no_grad():
        for batch, batch_labels in test_loader:
            for idx, label in enumerate(batch_labels):
                key = label.item()
                if key not in samples:
                    samples[key] = batch[idx:idx + 1]  # keep the batch dim
                if len(samples) == 5:
                    break
            else:
                # Batch exhausted before five digits were found: keep going.
                continue
            break

    _, axes = plt.subplots(5, 6, figsize=(15, 12))
    for row_axes, (digit, image) in zip(axes, samples.items()):
        image_ax, *feature_axes = row_axes

        # Original image in the first column.
        image_ax.imshow(image[0, 0], cmap='gray')
        image_ax.set_title(f'Digit {digit}')
        image_ax.axis('off')

        # Activations of the first convolution layer for this sample.
        with torch.no_grad():
            activations = model.relu(model.conv1(image))[0]

        # First five feature maps in the remaining columns.
        for number, ax in enumerate(feature_axes, start=1):
            ax.imshow(activations[number - 1], cmap='viridis')
            ax.set_title(f'Feature {number}')
            ax.axis('off')

    plt.suptitle('How Different Digits Activate Different Features')
    plt.tight_layout()
    plt.show()
# Compare first-layer activations across different digits from the test set.
compare_digit_features(model, test_loader)

output_11_0.png
# Section 13: experiment summary
# The printed text had romanisation fragments such as "(shí)" injected after
# many characters; the strings below are the same text with those removed.
print("=== MNIST CNN 實驗總結 ===")
print("\n1. Conv2D 形狀理解:")
print(" - Conv1: [16, 1, 5, 5] = [輸出通道, 輸入通道, 高, 寬]")
print(" - Conv2: [32, 16, 5, 5] = [輸出通道, 輸入通道, 高, 寬]")
print("\n2. 實驗收獲:")
print(" - CNN 能自動學習有用的特征檢測器")
print(" - 不同卷積核專門檢測不同的模式")
print(" - 特征圖顯示了網絡'看到'的內容")
print(" - 深層網絡學習更復雜的特征")
print("\n3. 性能表現:")
print(" - 僅用3個epoch就達到95%+準確率")
print(f" - 總參數量: {sum(p.numel() for p in model.parameters()):,}")
print(" - 比全連接網絡參數少得多")
print("\n4. 關鍵洞察:")
print(" - Conv2D 非常適合圖像識別任務")
print(" - 權重共享讓CNN參數效率很高")
print(" - 可視化幫助理解模型學到了什么")
print(" - 卷積核就是自動學習的特征檢測器")
print("\n5. Conv2D 核心原理:")
print(" - 輸入: [batch, 1, 28, 28] 灰度圖像")
print(" - Conv1: 1通道→16通道,學習16種基礎特征")
print(" - Conv2: 16通道→32通道,組合成32種復雜特征")
print(" - 最終: 32個7x7特征圖→全連接層→10個類別")
print("\n6. 為什么CNN這么強:")
print(" - 局部連接: 只關注鄰近像素,符合圖像特性")
print(" - 權重共享: 同一特征在圖像任何位置都能檢測")
print(" - 平移不變性: 數字在圖像中移動位置仍能識別")
print(" - 層次特征: 從簡單邊緣到復雜形狀逐層抽象")
=== MNIST CNN 實驗總結 ===
1. Conv2D 形狀理解:
- Conv1: [16, 1, 5, 5] = [輸出通道, 輸入通道, 高, 寬]
- Conv2: [32, 16, 5, 5] = [輸出通道, 輸入通道, 高, 寬]
2. 實驗收獲:
- CNN 能自動學習有用的特征檢測器
- 不同卷積核專門檢測不同的模式
- 特征圖顯示了網絡'看到'的內容
- 深層網絡學習更復雜的特征
3. 性能表現:
- 僅用3個epoch就達到95%+準確率
- 總參數量: 215,370
- 比全連接網絡參數少得多
4. 關鍵洞察:
- Conv2D 非常適合圖像識別任務
- 權重共享讓CNN參數效率很高
- 可視化幫助理解模型學到了什么
- 卷積核就是自動學習的特征檢測器
5. Conv2D 核心原理:
- 輸入: [batch, 1, 28, 28] 灰度圖像
- Conv1: 1通道→16通道,學習16種基礎特征
- Conv2: 16通道→32通道,組合成32種復雜特征
- 最終: 32個7x7特征圖→全連接層→10個類別
6. 為什么CNN這么強:
- 局部連接: 只關注鄰近像素,符合圖像特性
- 權重共享: 同一特征在圖像任何位置都能檢測
- 平移不變性: 數字在圖像中移動位置仍能識別
- 層次特征: 從簡單邊緣到復雜形狀逐層抽象