일 | 월 | 화 | 수 | 목 | 금 | 토 |
---|---|---|---|---|---|---|
1 | 2 | 3 | 4 | 5 | 6 | 7 |
8 | 9 | 10 | 11 | 12 | 13 | 14 |
15 | 16 | 17 | 18 | 19 | 20 | 21 |
22 | 23 | 24 | 25 | 26 | 27 | 28 |
29 | 30 |
- 그래프
- Backpropagation
- git add
- forward
- 분할정복
- git branch
- Heap
- 병리
- 딥러닝
- Git
- cv
- 밑바닥부터 시작하는 딥러닝 1
- Merge
- 그래프이론
- BFS
- Pathology
- git commit
- conflict
- 알고리즘
- Python
- computer vision
- add
- WSI
- 백준
- 오차역전파
- 파이썬
- Branch
- Segment Anything
- DFS
- git merge
- Today
- Total
나만의 길
WSI-SAM: Multi-resolution Segment AnythingModel (SAM) for histopathology whole-slideimages 리뷰 본문
WSI-SAM: Multi-resolution Segment AnythingModel (SAM) for histopathology whole-slideimages 리뷰
yunway 2024. 4. 26. 18:03https://arxiv.org/abs/2403.09257
WSI-SAM: Multi-resolution Segment Anything Model (SAM) for histopathology whole-slide images
The Segment Anything Model (SAM) marks a significant advancement in segmentation models, offering robust zero-shot abilities and dynamic prompting. However, existing medical SAMs are not suitable for the multi-scale nature of whole-slide images (WSIs), res
arxiv.org
Overview
SAM은 segmentation에 중요한 기여를 함. 그러나 multi scale을 가지는 WSI에 대한 SAM은 여전히 제한적임. 따라서 본 논문에서는 multi resoulution을 고려한 WSI-SAM을 제시함.
특히, SAM의 encoder는 frozen시키고, HR, LR token과 dual mak decoder를 학습함. 중간 layer에 같은 WSI에 대한 multi scale을 더해줌으로써 기존 SAM보다 높은 성능향상을 이룸.
SAM을 이용한 최근 연구에서 Med-SAM, Medical SAM Adapter는 좋은 성능을 보여줌. 그러나, 저자들은 해당 모델이 WSI의 multi scale 특성을 잘 다루지 못할거라는 가설을 세움. WSI는 multi scale로 구성되어 있고, 해당 모델은 단일 scale에 대해서만 학습되었기 때문.
따라서 WSI의 특성을 고려하여 WSI에 특화된 SAM을 만들고자함.
Method
HR 토큰
- 1x256 크기의 벡터
- 고해상도(10x) 패치의 특징을 encoding
- 고해상도 패치의 세부 정보를 담고 있음
LR 토큰
- 1x256 크기의 벡터
- 저해상도(5x) 패치의 특징을 인코딩
- 저해상도 패치의 전체적인 context 정보를 담고 있음
즉, HR의 디테일한 특징과 LR의 context 정보를 모두 고려하기 때문에 보다 정확한 mask prediction이 가능함.
Token Aggregation은 HR, LR의 평균으로 계산. 구체적으로 어떻게 encoding을 하고, fusion module에 대한 설명은 아카이브라 작성되지 않음.
Architecture
- TinyViT backbone을 사용
- Encoder freeze, decoder → dual mask decoder 설계
- \(L=\lambda L_{high} + (1-\lambda )L_{low}\) 각각의 loss는 Dice와 CE의 합으로 이루어짐.
- \(L_{high}= Dice_{high}+CE_{high}\)
- \(L_{low}= Dice_{low}+CE_{low}\)
import torch
import torch.nn as nn
import torch.nn.functional as F
class WSI_SAM_Loss(nn.Module):
def __init__(self, lambda_weight=0.5, smooth=1e-6):
super(WSI_SAM_Loss, self).__init__()
self.lambda_weight = lambda_weight
self.smooth = smooth
def forward(self, hr_pred, lr_pred, hr_target, lr_target):
hr_pred = torch.sigmoid(hr_pred)
lr_pred = torch.sigmoid(lr_pred)
hr_pred = hr_pred.flatten()
lr_pred = lr_pred.flatten()
hr_target = hr_target.flatten()
lr_target = lr_target.flatten()
# Dice Loss
hr_intersection = (hr_pred * hr_target).sum()
hr_dice_loss = 1 - (2. * hr_intersection + self.smooth) / (hr_pred.sum() + hr_target.sum() + self.smooth)
lr_intersection = (lr_pred * lr_target).sum()
lr_dice_loss = 1 - (2. * lr_intersection + self.smooth) / (lr_pred.sum() + lr_target.sum() + self.smooth)
# Cross Entropy Loss
hr_ce_loss = F.binary_cross_entropy(hr_pred, hr_target, reduction='mean')
lr_ce_loss = F.binary_cross_entropy(lr_pred, lr_target, reduction='mean')
# Combined Loss
hr_loss = hr_dice_loss + hr_ce_loss
lr_loss = lr_dice_loss + lr_ce_loss
total_loss = self.lambda_weight * hr_loss + (1 - self.lambda_weight) * lr_loss
return total_loss
Experiments
Train
- CATCH dataset으로 training
- prompt
- random point
- bbox
- mask → GT를 Gaussian noise를 섞어서 재구축(boundary 보정)
- mini batch 1, learning rate 0.001
- HW 제한으로 인해 10x, 5x WSI에 대해서만 진행
Test
- CAMELYON16(random 50개), DCIS(47개)로 진행
- SAM, MedSAM에 대해서 비교 실험 진행
- bbox prompt
- GT bbox에 noise를 줘서 비교
- nnU-Net으로 생성한 mask와 비교
- GT bbox에 noise를 줘서 비교
- Random point
Conclusion
아직 전체 논문이 공개되지 않아 Architecture의 정확한 구조는 파악하기 어려움. 그러나 본 논문에서 제시한 일부는 참고하면 좋을 것 같음.
'Paper Review > Pathology' 카테고리의 다른 글
Segment anything in medical images 리뷰 (0) | 2024.03.24 |
---|---|
Weakly supervised multiple instance learning histopathological tumor segmentation 리뷰 (0) | 2024.03.15 |
Deep Learning for Whole Slide Image Analysis : An Overview 리뷰 (2) | 2024.03.08 |