nnUNet v2模型训练

1.数据集处理

本文沿用 nnUNet 的 RGB PNG 数据流程(通道名为 R、G、B),故先将原数据集的单通道灰度图堆叠成三通道的 RGB 图像:

import cv2
import numpy as np
import os
def gray_to_rgb(image_gray):
    """Expand a single-channel image into a three-channel image.

    Every output channel is a copy of ``image_gray``; because the three
    channels are identical, the channel order of the result is irrelevant.

    Args:
        image_gray: 2-D array of shape ``(height, width)``.

    Returns:
        A new ``uint8`` array of shape ``(height, width, 3)``.

    Raises:
        ValueError: If ``image_gray`` is not two-dimensional.
    """
    if image_gray.ndim != 2:
        raise ValueError(
            f"expected a 2-D grayscale image, got shape {image_gray.shape}"
        )
    # Stack the same plane three times along a new last axis; the cast keeps
    # the uint8 output type the original implementation always produced.
    return np.stack((image_gray,) * 3, axis=-1).astype(np.uint8)
# Folder that receives the converted three-channel PNG files.
output_folder = ""  # TODO: set before running
# Root folder; each immediate sub-folder <name> is expected to contain a
# folder "<name>_label" holding the grayscale PNG files.
root_folder = ""  # TODO: set before running

if not output_folder or not root_folder:
    raise SystemExit("Set output_folder and root_folder before running this script.")

# Create the output folder if it does not exist yet.
os.makedirs(output_folder, exist_ok=True)

# Only the immediate sub-folders of root_folder are visited.
for dir_name in sorted(os.listdir(root_folder)):
    if not os.path.isdir(os.path.join(root_folder, dir_name)):
        continue

    input_folder = os.path.join(root_folder, dir_name, dir_name + "_label")
    print(input_folder)
    if not os.path.isdir(input_folder):
        print(f"Skipping {input_folder}: folder not found")
        continue

    for filename in os.listdir(input_folder):
        if not filename.endswith('.png'):
            continue

        # Full path of the grayscale input image.
        input_image_path = os.path.join(input_folder, filename)
        image_gray = cv2.imread(input_image_path, cv2.IMREAD_GRAYSCALE)
        if image_gray is None:
            # cv2.imread signals an unreadable file by returning None.
            print(f"Skipping {input_image_path}: could not be read")
            continue

        image_rgb = gray_to_rgb(image_gray)

        # Build "<stem>_rgb.png"; splitext only touches the final extension,
        # unlike str.replace, which would rewrite every ".png" in the name.
        # NOTE: all outputs share one flat folder, so equal file names from
        # different sub-folders overwrite each other.
        stem = os.path.splitext(filename)[0]
        output_image_path = os.path.join(output_folder, stem + '_rgb.png')
        cv2.imwrite(output_image_path, image_rgb)

print("Conversion completed.")

原先数据集的格式要求:

  • train
    • images
    • labels
  • test
    • images
    • labels

将数据集转化为 nnUNet 标准格式:参照并改写 nnUNet/nnunetv2/dataset_conversion/Dataset120_RoadSegmentation.py(以下为改写后的 `__main__` 部分):

if __name__ == "__main__":
    # Root of the source dataset. The code below reads train/images,
    # train/labels, test/images and test/labels beneath this folder.
    source = '/root/autodl-tmp/nnUNet/Coronary'
    print(source)
    dataset_name = 'Dataset150_Segmentation'
    nnUNet_raw = '/root/autodl-tmp/nnUNet/dataset/nnUNet_raw'
    imagestr = join(nnUNet_raw, dataset_name, 'imagesTr')
    imagests = join(nnUNet_raw, dataset_name, 'imagesTs')
    labelstr = join(nnUNet_raw, dataset_name, 'labelsTr')
    labelsts = join(nnUNet_raw, dataset_name, 'labelsTs')

    for folder in (imagestr, imagests, labelstr, labelsts):
        maybe_mkdir_p(folder)

    train_source = join(source, 'train')
    test_source = join(source, 'test')

    # Case lists are taken from the label folders, so an image without a
    # matching label file is never converted.
    train_ids = subfiles(join(train_source, 'labels'), join=False, suffix='png')
    test_ids = subfiles(join(test_source, 'labels'), join=False, suffix='png')
    num_train = len(train_ids)

    # (source folder, case file names, image output folder, label output folder)
    splits = (
        (train_source, train_ids, imagestr, labelstr),
        (test_source, test_ids, imagests, labelsts),
    )

    with multiprocessing.get_context("spawn").Pool(8) as p:
        r = []
        for split_source, case_ids, images_out, labels_out in splits:
            for v in case_ids:
                r.append(
                    p.starmap_async(
                        load_and_covnert_case,
                        ((
                            join(split_source, 'images', v),
                            join(split_source, 'labels', v),
                            # Strip ".png" and append the "_0000" suffix to the image name.
                            join(images_out, v[:-4] + '_0000.png'),
                            join(labels_out, v),
                            # NOTE(review): fifth positional argument of
                            # load_and_covnert_case; presumably a minimum
                            # component size - confirm against its definition.
                            50
                        ),)
                    )
                )
        # Wait for all conversions; .get() re-raises any worker exception.
        _ = [i.get() for i in r]

    generate_dataset_json(join(nnUNet_raw, dataset_name), {0: 'R', 1: 'G', 2: 'B'}, {'background': 0, 'coronary': 1},
                          num_train, '.png', dataset_name=dataset_name)

生成的数据集:

  • 数据集名称
    • imagesTr
    • imagesTs
    • labelsTr
    • labelsTs
    • dataset.json

添加环境变量:

export nnUNet_raw="/root/autodl-tmp/nnUNet/dataset/nnUNet_raw"
export nnUNet_preprocessed="/root/autodl-tmp/nnUNet/dataset/nnUNet_preprocessed"
export nnUNet_results="/root/autodl-tmp/nnUNet/dataset/nnUnet_results"

预处理数据集:

nnUNetv2_plan_and_preprocess -d 150 --verify_dataset_integrity  # 150 为数据集 ID

2.模型训练

开始训练:

nnUNetv2_train DATASET_NAME_OR_ID UNET_CONFIGURATION FOLD [additional options]

  • UNET_CONFIGURATION: 模型配置,可选 2d、3d_fullres、3d_lowres、3d_cascade_fullres(U-Net 级联)。
  • DATASET_NAME_OR_ID: 数据集全名 DatasetXXX_NAME(如 Dataset150_Segmentation)或其 ID 号(如 150)。
  • FOLD: 第几折交叉验证,可选 [0, 1, 2, 3, 4],一共五折。
nnUNetv2_train 150 2d 4

loss曲线:

image-20231106164804089

3.模型测试

nnUNetv2_predict -i "测试集路径" -o "输出路径" -chk checkpoint_best.pth -c 2d -f 4 -d 150 --save_probabilities

将二值掩码(像素值为 0 或 1)转换为 0 或 255:

from PIL import Image
import os

def process_images_in_folder(input_folder, output_folder):
    """Turn pixels with value 1 into white in every PNG below a folder.

    ``input_folder`` is walked recursively. Each file ending in ``.png`` is
    converted to RGB; every pixel that has the value 1 in any channel is set
    to ``(255, 255, 255)`` and all other pixels are left unchanged. The result
    is saved under the same file name directly in ``output_folder``; the
    sub-folder structure is not reproduced, so equal file names from different
    sub-folders overwrite each other.

    Args:
        input_folder: Folder that is searched recursively for ``.png`` files.
        output_folder: Folder that receives the rewritten images. It is
            created if it does not exist.

    Returns:
        None.
    """
    os.makedirs(output_folder, exist_ok=True)
    white = (255, 255, 255)

    for root, _dirs, files in os.walk(input_folder):
        for file in files:
            if not file.endswith(".png"):
                continue

            input_image_path = os.path.join(root, file)
            output_image_path = os.path.join(output_folder, file)

            # The context manager closes the source file handle; convert()
            # returns an independent image that stays usable afterwards.
            with Image.open(input_image_path) as source_image:
                image = source_image.convert("RGB")

            # Any channel equal to 1 marks the pixel as foreground.
            new_data = [white if 1 in pixel else pixel for pixel in image.getdata()]
            image.putdata(new_data)

            image.save(output_image_path)

# Paths of the input folder and the output folder.
# Both are placeholders and must be filled in before running.
input_folder_path = ""
output_folder_path = ""

# exist_ok=True replaces the separate "exists" check and is safe when the
# folder is already there.
os.makedirs(output_folder_path, exist_ok=True)

process_images_in_folder(input_folder_path, output_folder_path)

image-20231106200912078

评价指标:

HD95: 5.20

Average Dice: 0.8144