欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

【语义分割系列:四】Unet 论文阅读翻译笔记 医学图像 pytorch实现

程序员文章站 2022-07-05 11:06:21
...

UNet

2015 MICCAI
Olaf Ronneberger, Philipp Fischer, Thomas Brox


U-Net: Convolutional Networks for Biomedical Image Segmentation

5 Minute Teaser Presentation of the U-net 5分钟预告片介绍

GitHub

时间顺序: old ——> new

❤ orobix | U-net | 视网膜血管分割

ZijunDeng | PyTorch FCN、U-Net、SegNet、PSPNet 、GCN、DUC, HDC | 火车

zhixuhao | Unet, using Keras | ISBI 2012 神经元结构

milesial | Pytorch U-Net 图像语义分割 | car

❤ LeeJunHyun | pytorch U-Net, R2U-Net, Attention U-Net, Attention R2U-Net | ISIC 2018 皮肤病变分析向黑色素瘤

qubvel | Unet Linknet FPN PSPNet | car

1、Introduce

基于FCN的一个语义分割网络,适合用来做医学图像的分割。

  • contracting path :similar to an encoder
    用于获取上下文信息(context)
  • expanding path :similar to a decoder
    扩张路径用于精确的定位(localization)
  • data augmentation with elastic deformations

优势:

  • 能够从极少图像端对端进行训练,并且在ISBI竞赛中,对于分割电子显微镜中的神经元结构的表现好于以前最好的方法(滑动窗口卷积网络)。
  • 运行速度快

Architectural

  • VALID padding not SAME padding(因为边界用了镜像处理)
  • matched lower and upper features after cropping lower feature(The cropping is necessary due to the loss of border pixels in every convolution:注意看图1中的虚线框,就是与右分支对应的位置去crop左分支)
  • weighted cross entropy loss to separate instances(加权交叉熵损失)
  • elastic deformations (弹性变形)

Mirroring

边界的镜像处理

【语义分割系列:四】Unet 论文阅读翻译笔记 医学图像 pytorch实现

医学图像很大,所以要切成一张张小的patch,切成patch的时候因为Unet网络结构原因,适合切成Overlap-tile(重叠平铺)的切图。

Overlap-tile strategy for seamless segmentation of arbitrary large images
用于任意大图像的无缝分割的重叠平铺策略

观察左图:
白框是要分割区域,但是在切图的时候要包含周围区域(周围overlap部分可以为分割区域边缘部分提供文理信息)
预测黄色区域中的分割,需要蓝色区域内的图像数据作为输入。

可以从右图看出:
黄框区域分割结果没有因为切成小patch而造成分割情况不好。

通过镜像推断缺少输入数据。

data augmentation

主要需要:移位、旋转 不变性;变形、灰度值变化 鲁棒性

  • 弹性变换

在3*3的网格上使用随机位移矢量产生平滑形变,其中位移来自于10像素标准差的高斯分布,且通过双三次插值法计算得出。在收缩路径的末尾的drop-out层进一步暗示了数据增强。

separation of touching objects of the same class

propose the use of a weighted loss, where the separating background labels between touching cells obtain a large weight in the loss function.

【语义分割系列:四】Unet 论文阅读翻译笔记 医学图像 pytorch实现

【语义分割系列:四】Unet 论文阅读翻译笔记 医学图像 pytorch实现

  • 预先计算每个ground truth segmentation的权值图
  • 补偿训练数据集中某个类的不同像素出现的频率
  • 使网络学习我们在touch cells之间引入的小的分离边界(如图3c和d所示)

The separation border is computed using morphological operations. The
weight map is then computed as

【语义分割系列:四】Unet 论文阅读翻译笔记 医学图像 pytorch实现

Wc :the weight map to balance the class frequencies
d1 :表示到最近单元格边界的距离
d2 :到第二个最近单元格边界的距离
W0 = 10 pixels
σ = 5 pixels

Back propagation

反卷积可以进行反向传播,所以整体Unet是可以反向传播的。

2、Network

【语义分割系列:四】Unet 论文阅读翻译笔记 医学图像 pytorch实现

上图为U-net网络结构图(以最低分别率为32*32为例)。
每个蓝色框对应一个多通道特征图(map),其中通道数在框顶标,x-y的大小位于框的左下角。
白色框表示复制的特征图。箭头表示不同的操作。

U-net网络由一个收缩路径(左边)和一个扩张路径(右边)组成。

【语义分割系列:四】Unet 论文阅读翻译笔记 医学图像 pytorch实现


  • contracting path 收缩路径遵循典型的卷积网络结构,其由两个重复的3*3卷积核(无填充卷积,unpadded convolution)组成,且均使用修正线性单元(rectified linear unit,ReLU)**函数和一个用于下采样(downsample)的步长为2的2*2 max pooling 操作,以及在每一个下采样的步骤中,特征通道数量都加倍。

  • expanding path 扩张路径中,每一步都包含对特征图进行上采样(upsample);然后用2*2的卷积核进行卷积运算(上卷积,up-convolution),用于减少一半的特征通道数量;接着级联收缩路径中相应的裁剪后的特征图;再用两个3*3的卷积核进行卷积运算,且均使用ReLU**函数。

  • 在最后一层,利用1*1的卷积核进行卷积运算,将每个64维的特征向量映射网络的输出层。

  • 网络有23个卷积层。

  • 权重初始化:高斯(0,sigma=sqrt(2/N))
    图像增强采用仿射变换

3、Train

  • favor large input tiles

  • reduce batch size to a single image

  • a high momentum (0.99)

  • energy function

    • computed by a pixel-wise soft-max over the final feature map
    • combined with the cross entropy loss function

网络分为四个主要部分:预处理,向下卷积,向上卷积,输出映射

Example 1

运行你的第一个U-net进行图像分割 ISBI2012 神经元 keras

easy code

  • data ISBI 30train/30test
  • 数据集很小,只有30张,数据增强很必要

图像扭曲论文

Example 2

bag分类,FCN改成Unet,调试中发现的问题

错误1:Sequential 里面有逗号

class upsamping(nn.Module):
    def __init__(self,in_channels,out_channels):
        super(upsamping,self).__init__()
        self.up=nn.Sequential(
            nn.interpolate(scale_factor=2, mode='bilinear'),
            nn.Conv2d(in_channels,out_channels,kernel_size=3,stride=1,padding=1,bias=True),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            )
    def forward(self,x):
        x=self.up(x)
        return x

错误二:nn.函数名写错

nn.Conv2d(in_channels,out_channels,kernel_size=3,stride=1,padding=1,bias=True),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
self.maxpool=nn.MaxPool2d(kernel_size=2,stride=2)  # 2×2

错误三:忘记写super和init

def __init__(self,in_channels=3,n_class=2): 
        super(UNet,self).__init__()  
        self.maxpool=nn.MaxPool2d(kernel_size=2,stride=2)  # 2×2

错误四:UNet通道搞错了

忘记了cat之后通道数变了

错误五:Unet forward中output忘记return了

TypeError: sigmoid(): argument ‘input’ (position 1) must be Tensor, not NoneType

完整代码:

UNet.py

import torch
import torch.nn as nn
from torchvision import models
from torchvision.models.vgg import VGG
import torch.nn.functional as F
from torch.nn import init

class conv_block(nn.Module):
    def __init__(self,in_channels,out_channels):
        super(conv_block,self).__init__()
        self.conv=nn.Sequential(
            nn.Conv2d(in_channels,out_channels,kernel_size=3,stride=1,padding=1,bias=True),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels,out_channels,kernel_size=3,stride=1,padding=1,bias=True),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            )
    def forward(self,x):
        x=self.conv(x)
        return x

'''

nn.ReLU(inplace=True)
inplace=True意味着它将直接修改输入,而不分配任何额外的输出。它有时可以略微减少内存使用量,但可能并不总是有效的操作.


'''



class upsamping(nn.Module):
    def __init__(self,in_channels,out_channels):
        super(upsamping,self).__init__()
        self.up=nn.Sequential(
            # nn.interpolate(scale_factor=2, mode='bilinear'),
            nn.Upsample(scale_factor=2, mode='bilinear'),
            nn.Conv2d(in_channels,out_channels,kernel_size=3,stride=1,padding=1,bias=True),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            )
    def forward(self,x):
        x=self.up(x)
        return x


class UNet(nn.Module):

    def __init__(self,in_channels=3,n_class=2): 
        super(UNet,self).__init__()  
        self.maxpool=nn.MaxPool2d(kernel_size=2,stride=2)  # 2×2
        self.conv1=conv_block(in_channels,64)
        self.conv2=conv_block(64,128)
        self.conv3=conv_block(128,256)
        self.conv4=conv_block(256,512)
        self.conv5=conv_block(512,1024)
        self.upsamping5=upsamping(1024,512)
        self.upconv5=conv_block(1024,512)
        self.upsamping4=upsamping(512,256)
        self.upconv4=conv_block(512,256)
        self.upsamping3=upsamping(256,128)
        self.upconv3=conv_block(256,128)
        self.upsamping2=upsamping(128,64)
        self.upconv2=conv_block(128,64)
        self.upconv1=nn.Conv2d(64,n_class,kernel_size=1,stride=1,padding=0)   #和上面conv比没有 bias=True

    def forward(self,x):
        # contracting path 

        x1=self.conv1(x)  # [4, 64, 160, 160]

        x2=self.maxpool(x1)
        x2=self.conv2(x2)  # [4, 128, 80, 80]

        x3=self.maxpool(x2)
        x3=self.conv3(x3)  # [4, 256, 40, 40]

        x4=self.maxpool(x3)
        x4=self.conv4(x4)  # [4, 512, 20, 20]

        x5=self.maxpool(x4)
        x5=self.conv5(x5)  # [4, 1024, 10, 10]

        # expanding path 

        d5=self.upsamping5(x5)
        d5=torch.cat((x4,d5),dim=1)
        d5=self.upconv5(d5)  # [4, 512, 20, 20]

        d4=self.upsamping4(d5)
        d4=torch.cat((x3,d4),dim=1)
        d4=self.upconv4(d4)  # [4, 256, 40, 40]

        d3=self.upsamping3(d4)
        d3=torch.cat((x2,d3),dim=1)
        d3=self.upconv3(d3)  # [4, 128, 80, 80]

        d2=self.upsamping2(d3)
        d2=torch.cat((x1,d2),dim=1)
        d2=self.upconv2(d2)  # [4, 64, 160, 160]

        d1=self.upconv1(d2)   # [4, 2, 160, 160]
        return d1

train.py

# !/usr/bin/python
# -*- coding: utf-8 -*
'''
Train Bag with PyTorch.
FCN
U-Net
'''
from datetime import datetime

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import visdom

from BagData import test_dataloader, train_dataloader
from FCN import FCN8s, FCN16s, FCN32s, FCNs, VGGNet
from UNET import UNet

# 命令行解析的库文件
import sys
import argparse  


epo_num=50


# vis = visdom.Visdom()   # 可视化控件

device = torch.device('cuda:3' if torch.cuda.is_available() else 'cpu')
print('==> Building model..')

'''------------UNet--------------'''
unet_model = UNet()
unet_model = unet_model.to(device)

criterion = nn.BCELoss().to(device)
'''
BCELoss:二分类用的交叉熵,用的时候需要在该层前面加上 Sigmoid 函数。
CrossEntropyLoss:多分类用的交叉熵损失函数,用这个 loss 前面不需要加 Softmax 层。
'''
optimizer = optim.SGD(unet_model.parameters(), lr=1e-2, momentum=0.7)

all_train_iter_loss = []
all_test_iter_loss = []



# start timing
prev_time = datetime.now()
for epo in range(epo_num):  #  for each training  epoch
    print("\nEpoch: {}".format(epo))
    train_loss = 0
    unet_model.train()
    for index, (bag, bag_msk) in enumerate(train_dataloader): # 不停的 循环 这个 DataLoader 对象
        # if(index>0):
        #     sys.exit()

        # bag.shape is torch.Size([4, 3, 160, 160])
        # bag_msk.shape is torch.Size([4, 2, 160, 160])
        bag = bag.to(device)    # 
        bag_msk = bag_msk.to(device)
        optimizer.zero_grad()
        output = unet_model(bag)   # [4, 2, 160, 160]
        output = torch.sigmoid(output) #  ([4, 2, 160, 160])

        loss = criterion(output, bag_msk)
        loss.backward()
        optimizer.step()

        iter_loss = loss.item()  # 将tensor转换成python的scalars
        all_train_iter_loss.append(iter_loss)
        train_loss += iter_loss  # 计算总train loss


        '''
        .cpu() : Some operations on tensors cannot be performed on cuda tensors so you need to move them to cpu first.
        detach() : 截断反向传播的梯度流。
        '''
        output_np = output.cpu().detach().numpy().copy() # output_np.shape = (4, 2, 160, 160)
        #  转成热度图heatmap
        output_np = np.argmin(output_np, axis=1)    # 返回沿轴的最小值的索引   (4, 160, 160)

        '''
        min/max与np.argmin/np.argmax函数的功能不同:
        min/max  返回值,适合处理list等可迭代对象;
        np.argmin/np.argmax  返回最值所在的索引(下标)前者,适合处理numpy里的核心数据结构ndarray(多维数组)
        Returns:   index_array : 下标组成的数组。shape与输入数组a去掉axis的维度相同。
        '''
        bag_msk_np = bag_msk.cpu().detach().numpy().copy() # bag_msk_np.shape = (4, 2, 160, 160) 
        bag_msk_np = np.argmin(bag_msk_np, axis=1)   # (4, 160, 160)
        
        if np.mod(index, 15) == 0: # mod 返回两个元素相除后的余数
            print('epoch {}, {}/{},train loss is {}'.format(epo, index, len(train_dataloader), iter_loss))
        #     # vis.close()
        #     vis.images(output_np[:, None, :, :], win='train_pred', opts=dict(title='train prediction')) 
        #     vis.images(bag_msk_np[:, None, :, :], win='train_label', opts=dict(title='label'))
        #     vis.line(all_train_iter_loss, win='train_iter_loss',opts=dict(title='train iter loss'))
        # plt.subplot(1, 2, 1) 
        # plt.imshow(np.squeeze(bag_msk_np[0, ...]), 'gray')
        # plt.subplot(1, 2, 2) 
        # plt.imshow(np.squeeze(output_np[0, ...]), 'gray')
        # plt.pause(0.5)

    test_loss = 0
    unet_model.eval()
    with torch.no_grad():
        for index, (bag, bag_msk) in enumerate(test_dataloader):
            # if(index>0):
            #     sys.exit()
            bag = bag.to(device)
            bag_msk = bag_msk.to(device)

            optimizer.zero_grad()
            output = unet_model(bag)
            output = torch.sigmoid(output) # output.shape is torch.Size([4, 2, 160, 160])
            loss = criterion(output, bag_msk)
            iter_loss = loss.item()
            all_test_iter_loss.append(iter_loss)
            test_loss += iter_loss

            output_np = output.cpu().detach().numpy().copy() # output_np.shape = (4, 2, 160, 160)  
            output_np = np.argmin(output_np, axis=1)
            bag_msk_np = bag_msk.cpu().detach().numpy().copy() # bag_msk_np.shape = (4, 2, 160, 160) 
            bag_msk_np = np.argmin(bag_msk_np, axis=1)
    
            # if np.mod(index, 15) == 0:
                # print(r'Testing... Open http://localhost:8097/ to see test result.')
            #     # vis.close()
            #     vis.images(output_np[:, None, :, :], win='test_pred', opts=dict(title='test prediction')) 
            #     vis.images(bag_msk_np[:, None, :, :], win='test_label', opts=dict(title='label'))
            #     vis.line(all_test_iter_loss, win='test_iter_loss', opts=dict(title='test iter loss'))
            
            plt.subplot(1, 2, 1) 
            plt.imshow(np.squeeze(bag_msk_np[0, ...]), 'gray')
            plt.subplot(1, 2, 2) 
            plt.imshow(np.squeeze(output_np[0, ...]), 'gray')
            plt.pause(0.5)   # 暂停半秒钟

    # len(train_dataloader)=135  540÷4  train_dataset total / batch size
    # len(test_dataloader) = 15   60÷4   test_dataset total / batch size

    cur_time = datetime.now()
    h, remainder = divmod((cur_time - prev_time).seconds, 3600)
    # python divmod() 函数把除数和余数运算结果结合起来,返回一个包含商和余数的元组(a // b, a % b)。
    m, s = divmod(remainder, 60)
    time_str = "Time %02d:%02d:%02d" % (h, m, s)
    prev_time = cur_time

    print('epoch train loss = %f, epoch test loss = %f, %s'
            %(train_loss/len(train_dataloader), test_loss/len(test_dataloader), time_str))
    

    if np.mod(epo, 5) == 0:
        torch.save(unet_model, 'checkpoints/unet_model_{}.pt'.format(epo))
        print('saveing checkpoints/unet_model{}.pt'.format(epo)) 

Example 3

LeeJunHyun | pytorch U-Net, R2U-Net, Attention U-Net, Attention R2U-Net | ISIC 2018 皮肤病变分析向黑色素瘤

黑色素瘤检测的皮肤病变分析 论文 2018

【语义分割系列:四】Unet 论文阅读翻译笔记 医学图像 pytorch实现

先 下载 数据集,根据 dataset.py 中路径放到响应位置 ISIC/dataset/ …
先 运行 dataset.py

source activate pytorch1.0
python dataset.py

【语义分割系列:四】Unet 论文阅读翻译笔记 医学图像 pytorch实现

数据准备完成

 python main.py

lr =0.01 (每50步乘0.1)
batchsize=1
epoch=250

【语义分割系列:四】Unet 论文阅读翻译笔记 医学图像 pytorch实现

4、Other

Why U-Net?

  • 多模态

医疗影像是具有多种模态的。以ISLES脑梗竞赛为例,其官方提供了CBF,MTT,CBV,TMAX,CTP等多种模态的数据。

设计网络去提取不同模态的特征feature。

参考论文:

Joint Sequence Learning and Cross-Modality Convolution for 3D Biomedical Segmentation(CVPR 2017)

Dense Multi-path U-Net for Ischemic Stroke Lesion Segmentation in Multiple Image Modalities.

  • 可解释性

医疗影像最终是辅助医生的临床诊断,所以网络告诉医生有没有病是远远不够的,医生还要进一步的想知道,病灶在哪一层?哪个位置?分割了吗?体积?结果是为什么?
比较常用的就是画activation map。看网络的哪些区域被**了

参考论文:

Learning Deep Features for Discriminative Localization(CVPR2016)

Deep Learning for Identifying Metastatic Breast Cancer 2016

  • 医学图像语义较为简单、结构较为固定,底层的特征其实很重要

U-net利用了底层的特征(同分辨率级联)改善上采样的信息不足。底层信息有助于提高精度,高层信息用来提取复杂特征。

  • 和FCN区别

    • 多尺度
      基于FCNs做改进。U-Net特征提取部分,每经过一个池化层就一个尺度,包括原图尺度一共有5个尺度。

    • 上采样部分
      每上采样一次,就和特征提取部分对应的通道数相同尺度融合,但是融合之前要将其crop。这里的融合是concat而不是FCN的element-wise。

    • 适合超大 图像分割,适合医学图像分割

design choices

没发现这个有什么用( ╯□╰ )

“DeepLab的四个对齐规则”:
(1)在所有卷积和pooling中使用奇数大小的内核。
(2)在所有卷积和pooling中使用SAME边界条件。
(3)使用双线性插值对特征映射进行上采样时,使用align_corners = True。
(4)使用height/width等于output_stride的倍数的输入加1(for example, when the CNN output stride is 8, use height or width equal to 8 * n + 1, for some n, e.g., image HxW set to 321x513)

R2U-Net

2018 CVPR
Md Zahangir Alom, Mahmudul Hasan, Chris Yakopcic, Tarek M. Taha, Vijayan K. Asari

Recurrent Residual Convolutional Neural Network based on U-Net (R2U-Net) for Medical Image Segmentation

UNet++

2018 CVPR
Zongwei Zhou, Md Mahfuzur Rahman Siddiquee, Nima Tajbakhsh, Jianming Liang

UNet++: A Nested U-Net Architecture for Medical Image Segmentation

研习U-Net++

Attention U-Net

2018 CVPR
Ozan Oktay, Jo Schlemper, Loic Le Folgoc, Matthew Lee

Attention U-Net: Learning Where to Look for the Pancreas

nnU-Net

2019 CVPR
F* Isensee, Jens Petersen, Simon A. A. Kohl, Paul F. Jäger, Klaus H. Maier-Hein

nnU-Net: Breaking the Spell on Successful Medical Image Segmentation

github上暂时是空的,作者还没更新代码