nnUNet v2模型训练
nnUNet v2模型训练
1.数据集处理
nnUnet要求rgb-png格式的数据,故将原数据集由单通道堆叠成三通道的RGB图像
1 | import cv2 |
原先数据集的格式要求:
- train
- images
- labels
- test
- images
- labels
将数据集转化为nnUnet标准格式,改写nnUNet/nnunetv2/dataset_conversion/Dataset120_RoadSegmentation.py
1 | if __name__ == "__main__": |
生成的数据集:
- 数据集名称
- imagesTr
- imagesTs
- labelsTr
- labelsTs
- dataset.json
添加环境变量:
1 | export nnUNet_raw="/root/autodl-tmp/nnUNet/dataset/nnUNet_raw" |
预处理数据集:
1 | nnUNetv2_plan_and_preprocess -d 150 --verify_dataset_integrity #150为任务id |
2.模型训练
开始训练:
nnUNetv2_train CONFIGURATION TRAINER_CLASS_NAME TASK_NAME_OR_ID FOLD (additional options)
CONFIGURATION:
模型架构,三种Unet: 2D U-Net, 3D U-Net and a U-Net Cascade(U-Net级联)。TASK_NAME_OR_ID:
任务全名TaskXXX_MYTASK或者是ID号FOLD:
第几折交叉验证,可选 [0, 1, 2, 3, 4],一共五折。
1 | nnUNetv2_train 666 2d 4 |
loss曲线:
3.模型测试
1 | nnUNetv2_predict -i “测试集路径” -o “输出路径” -chk checkpoint_best.pth -c 2d -f 4 -d 150 --save_probabilities |
将二值掩码转换为0或255
1 | from PIL import Image |
评价指标:
HD95: 5.20
Average Dice: 0.8144
本博客所有文章除特别声明外,均采用 CC BY-NC-SA 4.0 许可协议。转载请注明来自 丹青两幻!
评论