3D tooth segmentation, from finding the data to the final result. Almost.
Disclaimer
This article is not educational in any sense of the word; it is purely informational. The author accepts no responsibility for the time you spend reading it.
About the author
Hello everyone, my name is Andrey (27). I'll try to be brief. Why programming? By education I hold a bachelor's degree in electromechanics, and I know the trade. I worked successfully for 2 years as a power engineer at a drilling company. Instead of taking a promotion, I handed in my notice: I had burned out, and the job turned out not to be for me. I love creating things and finding solutions to complex problems, and I have been inseparable from a PC for as long as I can remember. The choice was obvious. At first (six months ago) I seriously considered signing up for courses of one kind or another. I read the reviews, talked to participants and realized that getting hold of information is not a problem. So I found a site, picked up the basics of Python there and started my journey (now I am gradually studying everything related to ML). I was immediately drawn to machine learning, CV in particular. I ran into a problem, and here I am (for me, this is an excellent way to learn).
1. Introduction
After several unsuccessful attempts, I decided to use 2 lightweight models to get the desired result. The first segments all teeth as a single category [1, 0], and the second splits them into categories [0, 8]. But let's take things in order.
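As a rough illustration of the two-model idea, the sketch below chains a binary model and a 9-class model. `binary_net` and `multiclass_net` are hypothetical single-convolution stand-ins, not the U-Nets used in this article, and feeding the thresholded binary mask into the second model is an assumption made for the example.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins for the two trained networks (assumption: any
# model mapping (N, 1, D, H, W) to per-voxel scores would fit here).
binary_net = nn.Conv3d(1, 1, kernel_size=3, padding=1)      # teeth vs background
multiclass_net = nn.Conv3d(1, 9, kernel_size=3, padding=1)  # background + 8 classes


def two_stage_segmentation(volume: torch.Tensor) -> torch.Tensor:
    """Return a label volume with integer values in 0..8."""
    with torch.no_grad():
        # Stage 1: binary mask of all teeth, values in {0, 1}.
        teeth = (torch.sigmoid(binary_net(volume)) > 0.5).float()
        # Stage 2: score 9 classes per voxel, starting from the binary mask.
        scores = multiclass_net(teeth)
        labels = scores.argmax(dim=1, keepdim=True)
        # Keep class labels only where stage 1 found a tooth.
        return labels * teeth.long()


volume = torch.rand(1, 1, 8, 16, 16)  # small random volume for the demo
labels = two_stage_segmentation(volume)
print(labels.shape, labels.min().item(), labels.max().item())
```

With untrained stand-ins the labels are meaningless; the point is only the data flow: one channel in, a binary mask in between, nine class scores out.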
2. Finding and preparing the data
After spending more than one evening searching for data to work with, I came to the conclusion that a freely available jaw scan of good quality and in a suitable format (*.stl, *.nrrd, etc.) was not to be found. The best I came across was a sample dataset of a patient's head after jaw surgery in 3D Slicer.

Obviously, I don't need the whole head, so I cropped the source in the same program to 163 * 112 * 120 px (in this article {x * y * z = w * d * h} and 1 px = 0.5 mm), leaving only the teeth and the associated maxillofacial parts.

, - . . , - "autothreshold" , , , , ( ).

12~14. , 4 . , .

, ( ) , . , , N- , random-crop .
import nrrd
import torch
import torchvision.transforms as tf


class DataBuilder:
    def __init__(self,
                 data_path,
                 list_of_categories,
                 num_of_chunks: int = 0,
                 augmentation_coeff: int = 0,
                 num_of_classes: int = 0,
                 normalise: bool = False,
                 fit: bool = True,
                 data_format: int = 0,
                 save_data: bool = False
                 ):
        self.data_path = data_path
        self.number_of_chunks = num_of_chunks
        self.augmentation_coeff = augmentation_coeff
        self.list_of_cats = list_of_categories
        self.num_of_cls = num_of_classes
        self.normalise = normalise
        self.fit = fit
        self.data_format = data_format
        self.save_data = save_data

    def forward(self):
        # Every step after loading is optional and driven by the constructor arguments
        data = self.get_data()
        data = self.fit_data(data) if self.fit else data
        data = self.pre_normalize(data) if self.normalise else data
        data = self.data_augmentation(data, self.augmentation_coeff) if self.augmentation_coeff != 0 else data
        data = self.new_chunks(data, self.number_of_chunks) if self.number_of_chunks != 0 else data
        data = self.category_splitter(data, self.num_of_cls, self.list_of_cats) if self.num_of_cls != 0 else data
        if self.save_data:
            torch.save(data, self.data_path[-14:] + '.pt')
        # Add the channel dimension expected by the network
        return torch.unsqueeze(data, 1)

    def get_data(self):
        # data_format: 0 = *.nrrd file, 1 = saved tensor, 2 = tensor already in memory
        if self.data_format == 0:
            return torch.from_numpy(nrrd.read(self.data_path)[0])
        elif self.data_format == 1:
            return torch.load(self.data_path).cpu()
        elif self.data_format == 2:
            return torch.unsqueeze(self.data_path, 0).cpu()
        else:
            # Fail early instead of silently returning None
            raise ValueError('Available types are: "nrrd", "tensor" or "self.tensor(w/o load)"')

    @staticmethod
    def fit_data(some_data):
        # Zero-pad x by 5 and y by 8: (163, 112, 120) -> (168, 120, 120)
        data = torch.movedim(some_data, (1, 0), (0, -1))
        data_add_x = torch.nn.ZeroPad2d((5, 0, 0, 0))
        data = data_add_x(data)
        data = torch.movedim(data, -1, 0)
        data_add_z = torch.nn.ZeroPad2d((0, 0, 8, 0))
        return data_add_z(data)

    @staticmethod
    def pre_normalize(some_data):
        # Min-max scaling to the range 0...1
        min_d, max_d = torch.min(some_data), torch.max(some_data)
        return (some_data - min_d) / (max_d - min_d)

    @staticmethod
    def data_augmentation(some_data, aug_n):
        torch.manual_seed(17)
        tr_data = []
        for e in range(aug_n):
            # degrees=(a, a) gives a fixed rotation by exactly 20*e degrees
            transform = tf.RandomRotation(degrees=(20*e, 20*e))
            for image in some_data:
                image = torch.unsqueeze(image, 0)
                image = transform(image)
                tr_data.append(image)
        return tr_data

    def new_chunks(self, some_data, n_ch):
        # Stack the augmented slices and split them into n_ch equal chunks
        data = torch.stack(some_data, 0) if self.augmentation_coeff != 0 else some_data
        data = torch.squeeze(data, 1)
        chunks = torch.chunk(data, n_ch, 0)
        return torch.stack(chunks)

    @staticmethod
    def category_splitter(some_data, alpha, list_of_categories):
        # Values below each successive threshold are relabelled alpha, alpha + 1, ...;
        # subtracting alpha at the end maps them to 0, 1, ...
        # This relies on alpha being at least as large as every threshold.
        data, label = torch.squeeze(some_data, 1).to(torch.int64), alpha
        for i in list_of_categories:
            data = torch.where(data < i, label, data)
            label += 1
        return data - alpha
3D U-net. :
( ).
0 168*120*120 ( 163*112*120). * .
0...1 ( ~-2000...16000).
N- .
( 1, 1, 72, 120, 120).
28 (. ):
1-;
9 (8+) 2-.
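The padding and rescaling steps can also be written more compactly. The sketch below uses `torch.nn.functional.pad`, which is a swapped-in alternative to the `movedim`/`ZeroPad2d` combination in `DataBuilder.fit_data`, and a random tensor as a stand-in for the real scan (an assumption, so that the example runs without the data file).

```python
import torch
import torch.nn.functional as F

# Random stand-in for the cropped scan (assumption: real data comes from *.nrrd).
volume = torch.randint(-2000, 16000, (163, 112, 120)).float()

# F.pad lists the pads from the last dimension backwards:
# (z_before, z_after, y_before, y_after, x_before, x_after).
padded = F.pad(volume, (0, 0, 8, 0, 5, 0), mode="constant", value=0.0)

# Min-max rescaling to the 0...1 range.
lo, hi = padded.min(), padded.max()
normalised = (padded - lo) / (hi - lo)

print(padded.shape)                                       # torch.Size([168, 120, 120])
print(normalised.min().item(), normalised.max().item())   # 0.0 1.0
```

The original voxels end up at `padded[5:, 8:, :]`, matching the side on which `fit_data` adds its zeros.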
Dataloader
import torch.utils.data as tud


class ToothDataset(tud.Dataset):
    def __init__(self, images, masks):
        self.images = images
        self.masks = masks

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        # Return an (image, mask) pair when masks exist, otherwise the image alone
        if self.masks is not None:
            return self.images[index, :, :, :, :],\
                   self.masks[index, :, :, :, :]
        else:
            return self.images[index, :, :, :, :]


def get_loaders(images, masks,
                batch_size: int = 1,
                num_workers: int = 1,
                pin_memory: bool = True):
    train_ds = ToothDataset(images=images,
                            masks=masks)
    data_loader = tud.DataLoader(train_ds,
                                 batch_size=batch_size,
                                 shuffle=False,
                                 num_workers=num_workers,
                                 pin_memory=pin_memory)
    return data_loader
:
| | Semantic | Instance | Predictions |
|---|---|---|---|
| Data | (27*, 1, 56*, 120, 120) [0...1] | (27*, 1, 56*, 120, 120) [0, 1] | (1, 1, 168, 120, 120) [0...1] |
| Masks | (27*, 1, 56*, 120, 120) [0, 1] | (27*, 1, 56*, 120, 120) [0, 8] | - |
* , , - .
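The starred sizes listed above follow from simple arithmetic, assuming (as the `DataBuilder` code suggests) that augmentation stacks rotated copies of the 168-slice volume and `torch.chunk` then splits the stack evenly. The augmentation factor of 9 is inferred from 1512 / 168 and is an assumption; the chunk depths 24 and 72 are the ones used later in the article.

```python
slices_per_volume = 168      # depth after padding
augmentation_factor = 9      # assumption: 1512 / 168 rotated copies

total_slices = slices_per_volume * augmentation_factor


def chunk_count(total: int, depth: int) -> int:
    """Number of equal chunks of the given depth; fails if the split is uneven."""
    if total % depth != 0:
        raise ValueError(f"{total} slices do not split into chunks of {depth}")
    return total // depth


for depth in (56, 24, 72):
    print(depth, chunk_count(total_slices, depth))
# 56 27
# 24 63
# 72 21
```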
3.
- . U-Net. , .

, . - Adam, Dice-loss(implement), / 4, [64, 128, 256, 512] (, , - ). 60-80 epochs . Transfer learning .
model.summary()
model = UNet(dim=2, in_channels=1, out_channels=1, n_blocks=4, start_filters=64).to(device)
print(summary(model, (1, 168, 120)))
"""
----------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================
Conv2d-1 [-1, 64, 168, 120] 640
ReLU-2 [-1, 64, 168, 120] 0
BatchNorm2d-3 [-1, 64, 168, 120] 128
Conv2d-4 [-1, 64, 168, 120] 36,928
ReLU-5 [-1, 64, 168, 120] 0
BatchNorm2d-6 [-1, 64, 168, 120] 128
MaxPool2d-7 [-1, 64, 84, 60] 0
DownBlock-8 [[-1, 64, 84, 60], [-1, 64, 168, 120]] 0
Conv2d-9 [-1, 128, 84, 60] 73,856
ReLU-10 [-1, 128, 84, 60] 0
BatchNorm2d-11 [-1, 128, 84, 60] 256
Conv2d-12 [-1, 128, 84, 60] 147,584
ReLU-13 [-1, 128, 84, 60] 0
BatchNorm2d-14 [-1, 128, 84, 60] 256
MaxPool2d-15 [-1, 128, 42, 30] 0
DownBlock-16 [[-1, 128, 42, 30], [-1, 128, 84, 60]] 0
Conv2d-17 [-1, 256, 42, 30] 295,168
ReLU-18 [-1, 256, 42, 30] 0
BatchNorm2d-19 [-1, 256, 42, 30] 512
Conv2d-20 [-1, 256, 42, 30] 590,080
ReLU-21 [-1, 256, 42, 30] 0
BatchNorm2d-22 [-1, 256, 42, 30] 512
MaxPool2d-23 [-1, 256, 21, 15] 0
DownBlock-24 [[-1, 256, 21, 15], [-1, 256, 42, 30]] 0
Conv2d-25 [-1, 512, 21, 15] 1,180,160
ReLU-26 [-1, 512, 21, 15] 0
BatchNorm2d-27 [-1, 512, 21, 15] 1,024
Conv2d-28 [-1, 512, 21, 15] 2,359,808
ReLU-29 [-1, 512, 21, 15] 0
BatchNorm2d-30 [-1, 512, 21, 15] 1,024
DownBlock-31 [[-1, 512, 21, 15], [-1, 512, 21, 15]] 0
ConvTranspose2d-32 [-1, 256, 42, 30] 524,544
ReLU-33 [-1, 256, 42, 30] 0
BatchNorm2d-34 [-1, 256, 42, 30] 512
Concatenate-35 [-1, 512, 42, 30] 0
Conv2d-36 [-1, 256, 42, 30] 1,179,904
ReLU-37 [-1, 256, 42, 30] 0
BatchNorm2d-38 [-1, 256, 42, 30] 512
Conv2d-39 [-1, 256, 42, 30] 590,080
ReLU-40 [-1, 256, 42, 30] 0
BatchNorm2d-41 [-1, 256, 42, 30] 512
UpBlock-42 [-1, 256, 42, 30] 0
ConvTranspose2d-43 [-1, 128, 84, 60] 131,200
ReLU-44 [-1, 128, 84, 60] 0
BatchNorm2d-45 [-1, 128, 84, 60] 256
Concatenate-46 [-1, 256, 84, 60] 0
Conv2d-47 [-1, 128, 84, 60] 295,040
ReLU-48 [-1, 128, 84, 60] 0
BatchNorm2d-49 [-1, 128, 84, 60] 256
Conv2d-50 [-1, 128, 84, 60] 147,584
ReLU-51 [-1, 128, 84, 60] 0
BatchNorm2d-52 [-1, 128, 84, 60] 256
UpBlock-53 [-1, 128, 84, 60] 0
ConvTranspose2d-54 [-1, 64, 168, 120] 32,832
ReLU-55 [-1, 64, 168, 120] 0
BatchNorm2d-56 [-1, 64, 168, 120] 128
Concatenate-57 [-1, 128, 168, 120] 0
Conv2d-58 [-1, 64, 168, 120] 73,792
ReLU-59 [-1, 64, 168, 120] 0
BatchNorm2d-60 [-1, 64, 168, 120] 128
Conv2d-61 [-1, 64, 168, 120] 36,928
ReLU-62 [-1, 64, 168, 120] 0
BatchNorm2d-63 [-1, 64, 168, 120] 128
UpBlock-64 [-1, 64, 168, 120] 0
Conv2d-65 [-1, 1, 168, 120] 65
================================================================
Total params: 7,702,721
Trainable params: 7,702,721
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.08
Forward/backward pass size (MB): 7434.08
Params size (MB): 29.38
Estimated Total Size (MB): 7463.54
"""
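The parameter counts in the summary above can be reproduced by hand. The sketch below rebuilds a few individual layers with plain `torch.nn` modules; the kernel sizes (3x3 convolutions with padding 1, 2x2 transposed convolutions with stride 2) are inferred from the counts and output shapes, so treat them as assumptions about the `UNet` class rather than its documented configuration.

```python
import torch.nn as nn


def n_params(layer: nn.Module) -> int:
    """Total number of learnable values in a layer."""
    return sum(p.numel() for p in layer.parameters())


# Hyperparameters inferred from the summary (assumption, see the text above).
first_conv = nn.Conv2d(1, 64, kernel_size=3, padding=1)
second_conv = nn.Conv2d(64, 64, kernel_size=3, padding=1)
first_up = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
batch_norm = nn.BatchNorm2d(64)

print(n_params(first_conv))   # 1*64*3*3 + 64 = 640
print(n_params(second_conv))  # 64*64*3*3 + 64 = 36928
print(n_params(first_up))     # 512*256*2*2 + 256 = 524544
print(n_params(batch_norm))   # 64 scales + 64 shifts = 128
```

The same formula with a 3x3x3 kernel explains why the 3D variants further down carry roughly three times as many weights per convolution.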
![Fig. 1. 2D U-Net, [x, z]](https://habrastorage.org/getpro/habr/upload_files/d24/21a/246/d2421a24672cfc2928bf98a22ebd3bc7.gif)
, - . , . numpy - *.stl 6. , :
![1. [x, y]. 2. [x, z]. 3. [y, z]](https://habrastorage.org/getpro/habr/upload_files/75b/d08/ffb/75bd08ffbdeac83181abf79355378872.png)
100% , ? , .
, , , , , .
![Fig. 2. 2nd 2D U-Net, [y, z]](https://habrastorage.org/getpro/habr/upload_files/eef/335/e4e/eef335e4ebb6f3a3d23709cf81948e40.png)
, , :
![Fig. 3. 2nd 2D U-Net, [y, z], 50%](https://habrastorage.org/getpro/habr/upload_files/3ce/fc5/3d5/3cefc53d5fee0fdf2693b95e6db91152.png)
3D . , (24*, 120, 120). ? - (~22. ). (1063gtx) .
24*
. :
(1512, 120, 120) - 63;
batch size (24, 120, 120) - , ;
(24) / ( 24/2/2/2=3 3*2*2*2=24, / 2 / 1);
, . .summary()
model.summary()
model = UNet(dim=3, in_channels=1, out_channels=1, n_blocks=4, start_filters=64).to(device)
print(summary(model, (1, 24, 120, 120)))
"""
----------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================
Conv3d-1 [-1, 64, 24, 120, 120] 1,792
ReLU-2 [-1, 64, 24, 120, 120] 0
BatchNorm3d-3 [-1, 64, 24, 120, 120] 128
Conv3d-4 [-1, 64, 24, 120, 120] 110,656
ReLU-5 [-1, 64, 24, 120, 120] 0
BatchNorm3d-6 [-1, 64, 24, 120, 120] 128
MaxPool3d-7 [-1, 64, 12, 60, 60] 0
DownBlock-8 [[-1, 64, 12, 60, 60], [-1, 64, 24, 120, 120]] 0
Conv3d-9 [-1, 128, 12, 60, 60] 221,312
ReLU-10 [-1, 128, 12, 60, 60] 0
BatchNorm3d-11 [-1, 128, 12, 60, 60] 256
Conv3d-12 [-1, 128, 12, 60, 60] 442,496
ReLU-13 [-1, 128, 12, 60, 60] 0
BatchNorm3d-14 [-1, 128, 12, 60, 60] 256
MaxPool3d-15 [-1, 128, 6, 30, 30] 0
DownBlock-16 [[-1, 128, 6, 30, 30], [-1, 128, 12, 60, 60]] 0
Conv3d-17 [-1, 256, 6, 30, 30] 884,992
ReLU-18 [-1, 256, 6, 30, 30] 0
BatchNorm3d-19 [-1, 256, 6, 30, 30] 512
Conv3d-20 [-1, 256, 6, 30, 30] 1,769,728
ReLU-21 [-1, 256, 6, 30, 30] 0
BatchNorm3d-22 [-1, 256, 6, 30, 30] 512
MaxPool3d-23 [-1, 256, 3, 15, 15] 0
DownBlock-24 [[-1, 256, 3, 15, 15], [-1, 256, 6, 30, 30]] 0
Conv3d-25 [-1, 512, 3, 15, 15] 3,539,456
ReLU-26 [-1, 512, 3, 15, 15] 0
BatchNorm3d-27 [-1, 512, 3, 15, 15] 1,024
Conv3d-28 [-1, 512, 3, 15, 15] 7,078,400
ReLU-29 [-1, 512, 3, 15, 15] 0
BatchNorm3d-30 [-1, 512, 3, 15, 15] 1,024
DownBlock-31 [[-1, 512, 3, 15, 15], [-1, 512, 3, 15, 15]] 0
ConvTranspose3d-32 [-1, 256, 6, 30, 30] 1,048,832
ReLU-33 [-1, 256, 6, 30, 30] 0
BatchNorm3d-34 [-1, 256, 6, 30, 30] 512
Concatenate-35 [-1, 512, 6, 30, 30] 0
Conv3d-36 [-1, 256, 6, 30, 30] 3,539,200
ReLU-37 [-1, 256, 6, 30, 30] 0
BatchNorm3d-38 [-1, 256, 6, 30, 30] 512
Conv3d-39 [-1, 256, 6, 30, 30] 1,769,728
ReLU-40 [-1, 256, 6, 30, 30] 0
BatchNorm3d-41 [-1, 256, 6, 30, 30] 512
UpBlock-42 [-1, 256, 6, 30, 30] 0
ConvTranspose3d-43 [-1, 128, 12, 60, 60] 262,272
ReLU-44 [-1, 128, 12, 60, 60] 0
BatchNorm3d-45 [-1, 128, 12, 60, 60] 256
Concatenate-46 [-1, 256, 12, 60, 60] 0
Conv3d-47 [-1, 128, 12, 60, 60] 884,864
ReLU-48 [-1, 128, 12, 60, 60] 0
BatchNorm3d-49 [-1, 128, 12, 60, 60] 256
Conv3d-50 [-1, 128, 12, 60, 60] 442,496
ReLU-51 [-1, 128, 12, 60, 60] 0
BatchNorm3d-52 [-1, 128, 12, 60, 60] 256
UpBlock-53 [-1, 128, 12, 60, 60] 0
ConvTranspose3d-54 [-1, 64, 24, 120, 120] 65,600
ReLU-55 [-1, 64, 24, 120, 120] 0
BatchNorm3d-56 [-1, 64, 24, 120, 120] 128
Concatenate-57 [-1, 128, 24, 120, 120] 0
Conv3d-58 [-1, 64, 24, 120, 120] 221,248
ReLU-59 [-1, 64, 24, 120, 120] 0
BatchNorm3d-60 [-1, 64, 24, 120, 120] 128
Conv3d-61 [-1, 64, 24, 120, 120] 110,656
ReLU-62 [-1, 64, 24, 120, 120] 0
BatchNorm3d-63 [-1, 64, 24, 120, 120] 128
UpBlock-64 [-1, 64, 24, 120, 120] 0
Conv3d-65 [-1, 1, 24, 120, 120] 65
================================================================
Total params: 22,400,321
Trainable params: 22,400,321
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.61
Forward/backward pass size (MB): 15974.12
Params size (MB): 85.45
Estimated Total Size (MB): 16060.18
----------------------------------------------------------------
"""
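The depth arithmetic from the list above (24/2/2/2 = 3, then 3*2*2*2 = 24) generalises to every spatial axis. Here is a minimal sketch that checks whether a size survives the encoder/decoder round trip, assuming each pooling step halves the size with floor division and each up-convolution doubles it; the function name `round_trip` is made up for this example.

```python
def round_trip(size: int, n_blocks: int = 4) -> tuple:
    """Return (bottleneck size, restored size) for one spatial axis."""
    n_pools = n_blocks - 1          # the summary shows 3 MaxPool layers for 4 blocks
    bottleneck = size
    for _ in range(n_pools):
        bottleneck //= 2            # 2x max pooling, floor division
    restored = bottleneck * 2 ** n_pools
    return bottleneck, restored


for size in (24, 120, 168, 163):
    print(size, round_trip(size))
# 24 (3, 24)
# 120 (15, 120)
# 168 (21, 168)
# 163 (20, 160)
```

The last row shows why the raw 163-voxel axis needs padding: after three poolings it comes back as 160, which no longer lines up with the skip connection.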
![Fig. 4. 3D U-Net, [y, z], *0.38](https://habrastorage.org/getpro/habr/upload_files/618/cdd/83d/618cdd83db3c6855c2f916a11303c960.png)
~60% (25 epochs) , .
![Fig. 5. 3D U-Net, [y, z], 65 epochs ~ 1.5](https://habrastorage.org/getpro/habr/upload_files/7c7/e67/19a/7c7e6719a026fb11df22cd32eab03620.png)
. , (.â3) - :
![Fig. 6. 3D U-Net, [x, z], 105 epochs ~ 2.1](https://habrastorage.org/getpro/habr/upload_files/bc8/372/a07/bc8372a07b3aa13e874b0f9abfc4d21a.png)
"" . ~400 ( ~22) [18, 32, 64, 128] / 3. RMSProp. (1, 1, 72*, 120, 120). ?
model.summary()
model = UNet(dim=3, in_channels=1, out_channels=1, n_blocks=3, start_filters=18).to(device)
print(summary(model, (1, 72, 120, 120)))
"""
----------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================
Conv3d-1 [-1, 18, 72, 120, 120] 504
ReLU-2 [-1, 18, 72, 120, 120] 0
BatchNorm3d-3 [-1, 18, 72, 120, 120] 36
Conv3d-4 [-1, 18, 72, 120, 120] 8,766
ReLU-5 [-1, 18, 72, 120, 120] 0
BatchNorm3d-6 [-1, 18, 72, 120, 120] 36
MaxPool3d-7 [-1, 18, 36, 60, 60] 0
DownBlock-8 [[-1, 18, 36, 60, 60], [-1, 18, 24, 120, 120]] 0
Conv3d-9 [-1, 36, 36, 60, 60] 17,532
ReLU-10 [-1, 36, 36, 60, 60] 0
BatchNorm3d-11 [-1, 36, 36, 60, 60] 72
Conv3d-12 [-1, 36, 36, 60, 60] 35,028
ReLU-13 [-1, 36, 36, 60, 60] 0
BatchNorm3d-14 [-1, 36, 36, 60, 60] 72
MaxPool3d-15 [-1, 36, 18, 30, 30] 0
DownBlock-16 [[-1, 36, 18, 30, 30], [-1, 36, 36, 60, 60]] 0
Conv3d-17 [-1, 72, 18, 30, 30] 70,056
ReLU-18 [-1, 72, 18, 30, 30] 0
BatchNorm3d-19 [-1, 72, 18, 30, 30] 144
Conv3d-20 [-1, 72, 18, 30, 30] 140,040
ReLU-21 [-1, 72, 18, 30, 30] 0
BatchNorm3d-22 [-1, 72, 18, 30, 30] 144
DownBlock-23 [[-1, 72, 18, 30, 30], [-1, 72, 18, 30, 30]] 0
ConvTranspose3d-24 [-1, 36, 36, 60, 60] 20,772
ReLU-25 [-1, 36, 36, 60, 60] 0
BatchNorm3d-26 [-1, 36, 36, 60, 60] 72
Concatenate-27 [-1, 72, 36, 60, 60] 0
Conv3d-28 [-1, 36, 36, 60, 60] 70,020
ReLU-29 [-1, 36, 36, 60, 60] 0
BatchNorm3d-30 [-1, 36, 36, 60, 60] 72
Conv3d-31 [-1, 36, 36, 60, 60] 35,028
ReLU-32 [-1, 36, 36, 60, 60] 0
BatchNorm3d-33 [-1, 36, 36, 60, 60] 72
UpBlock-34 [-1, 36, 36, 60, 60] 0
ConvTranspose3d-35 [-1, 18, 72, 120, 120] 5,202
ReLU-36 [-1, 18, 72, 120, 120] 0
BatchNorm3d-37 [-1, 18, 72, 120, 120] 36
Concatenate-38 [-1, 36, 72, 120, 120] 0
Conv3d-39 [-1, 18, 72, 120, 120] 17,514
ReLU-40 [-1, 18, 72, 120, 120] 0
BatchNorm3d-41 [-1, 18, 72, 120, 120] 36
Conv3d-42 [-1, 18, 72, 120, 120] 8,766
ReLU-43 [-1, 18, 72, 120, 120] 0
BatchNorm3d-44 [-1, 18, 72, 120, 120] 36
UpBlock-45 [-1, 18, 72, 120, 120] 0
Conv3d-46 [-1, 1, 72, 120, 120] 19
================================================================
Total params: 430,075
Trainable params: 430,075
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 1.32
Forward/backward pass size (MB): 5744.38
Params size (MB): 1.64
Estimated Total Size (MB): 5747.34
----------------------------------------------------------------
"""
72*
, (168, 120, 120), (72, 120, 120). , . , 2 , . 9 (1512, 120, 120) .. 9 , 21(batch size) (72, 120, 120). 72 , 24*().
![Fig. 7. 3D U-Net, [x, z], (65 epochs) ~ 14](https://habrastorage.org/getpro/habr/upload_files/9d5/497/b04/9d5497b04c3c3a80a2c6275e5148cb2c.png)
, ( "" ). , . semantic segmentation , .
3D ( ) (1512, 120, 120) --> 21*(1, 72, 120, 120), ~*(30, 30, 30) ( ). 2 : 3- , ( ); , .
, 1 epochs "" ~13, 2 (>80). 1 epochs. , .
. 8 + . loss function .
training loop
import torch
from tqdm import tqdm
from _loss_f import LossFunction


class TrainFunction:
    def __init__(self,
                 data_loader,
                 device_for_training,
                 model_name,
                 model_name_pretrained,
                 model,
                 optimizer,
                 scale,
                 learning_rate: float = 1e-2,
                 num_epochs: int = 1,
                 transfer_learning: bool = False,
                 binary_loss_f: bool = True
                 ):
        self.data_loader = data_loader
        self.device = device_for_training
        self.model_name_pretrained = model_name_pretrained
        self.semantic_binary = binary_loss_f
        self.num_epochs = num_epochs
        self.model_name = model_name
        self.transfer = transfer_learning
        self.optimizer = optimizer
        self.learning_rate = learning_rate
        self.model = model
        self.scale = scale

    def forward(self):
        print('Running on the:', torch.cuda.get_device_name(self.device))
        # Optionally start from previously saved weights
        if self.transfer:
            self.model.load_state_dict(torch.load(self.model_name_pretrained))
        optimizer = self.optimizer(self.model.parameters(), lr=self.learning_rate)
        for epoch in range(self.num_epochs):
            self.train_loop(self.data_loader, self.model, optimizer, self.scale, epoch)
            # Save a checkpoint every 10 epochs
            if (epoch + 1) % 10 == 0:
                torch.save(self.model.state_dict(),
                           'models/' + self.model_name + str(epoch + 1) + '_epoch.pth')

    def train_loop(self, loader, model, optimizer, scales, i):
        loop, epoch_loss = tqdm(loader), 0
        # The progress bar shows the number of epochs still to run
        loop.set_description('Epoch %i' % (self.num_epochs - i))
        for batch_idx, (data, targets) in enumerate(loop):
            data, targets = data.to(device=self.device, dtype=torch.float), \
                            targets.to(device=self.device, dtype=torch.long)
            optimizer.zero_grad()
            # Forward pass under automatic mixed precision
            with torch.cuda.amp.autocast():
                predictions = model(data)
                loss = LossFunction(predictions, targets,
                                    device_for_training=self.device,
                                    semantic_binary=self.semantic_binary
                                    ).forward()
            # Backward pass; `scales` is used like a torch.cuda.amp.GradScaler
            scales.scale(loss).backward()
            scales.step(optimizer)
            scales.update()
            # Report (1 - Dice loss) as a percentage
            epoch_loss += (1 - loss.item())*100
            loop.set_postfix(loss=loss.item())
        print('Epoch-acc', round(epoch_loss / (batch_idx+1), 2))
4.
Dice-loss , '' [0, 1]. , ( [0, 1]), ( "" "" ) Dice-loss , .
categorical_dice_loss
import torch


class LossFunction:
    def __init__(self,
                 prediction,
                 target,
                 device_for_training,
                 semantic_binary: bool = True,
                 ):
        self.prediction = prediction
        self.device = device_for_training
        self.target = target
        self.semantic_binary = semantic_binary

    def forward(self):
        if self.semantic_binary:
            return self.dice_loss(self.prediction, self.target)
        return self.categorical_dice_loss(self.prediction, self.target)

    @staticmethod
    def dice_loss(predictions, targets, alpha=1e-5):
        # Dice with squared terms in the denominator; alpha avoids division by zero
        intersection = 2. * (predictions * targets).sum()
        denomination = (torch.square(predictions) + torch.square(targets)).sum()
        dice_loss = 1 - torch.mean((intersection + alpha) / (denomination + alpha))
        return dice_loss

    def categorical_dice_loss(self, prediction, target):
        pr, tr = self.prepare_for_multiclass_loss_f(prediction, target)
        # Only the categories actually present in the target are scored
        target_categories, losses = torch.unique(tr).tolist(), 0
        for num_category in target_categories:
            # Binary mask of the current category vs. its prediction channel
            categorical_target = torch.where(tr == num_category, 1, 0)
            categorical_prediction = pr[num_category]
            losses += self.dice_loss(categorical_prediction, categorical_target).to(self.device)
        return losses / len(target_categories)

    @staticmethod
    def prepare_for_multiclass_loss_f(prediction, target):
        # Drop the batch dimension (batch size 1) and the target's channel dimension
        prediction_prepared = torch.squeeze(prediction, 0)
        target_prepared = torch.squeeze(target, 0)
        target_prepared = torch.squeeze(target_prepared, 0)
        return prediction_prepared, target_prepared
, "categorical_dice_loss":
( );
, batch ;
"" "" , [0, 1] Dice-loss;
, batct. .
, , one-hot , ( ), , . , , , . (5).
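For comparison with the per-category loop in `categorical_dice_loss`, here is a minimal sketch of a vectorised multi-class Dice loss built on one-hot targets. It is an alternative formulation, not the implementation used in this article; applying `softmax` to the raw outputs and averaging over all nine classes (including any that are absent from the batch) are assumptions.

```python
import torch
import torch.nn.functional as F


def one_hot_dice_loss(logits: torch.Tensor, target: torch.Tensor,
                      eps: float = 1e-5) -> torch.Tensor:
    """logits: (N, C, D, H, W) raw scores; target: (N, D, H, W) integer labels."""
    num_classes = logits.shape[1]
    probs = torch.softmax(logits, dim=1)
    # (N, D, H, W, C) -> (N, C, D, H, W) to line up with the predictions.
    one_hot = F.one_hot(target, num_classes).permute(0, 4, 1, 2, 3).float()
    dims = (0, 2, 3, 4)                      # sum over everything except classes
    intersection = 2.0 * (probs * one_hot).sum(dims)
    denominator = (probs ** 2 + one_hot ** 2).sum(dims)
    dice_per_class = (intersection + eps) / (denominator + eps)
    return 1.0 - dice_per_class.mean()


torch.manual_seed(0)
target = torch.randint(0, 9, (1, 4, 8, 8))
# Near-perfect prediction: a very large score for the correct class.
perfect = F.one_hot(target, 9).permute(0, 4, 1, 2, 3).float() * 50.0
random_logits = torch.randn(1, 9, 4, 8, 8)

perfect_loss = one_hot_dice_loss(perfect, target).item()
random_loss = one_hot_dice_loss(random_logits, target).item()
print(perfect_loss)   # close to 0
print(random_loss)    # noticeably larger
```

The denominator uses squared terms, the same Dice variant as `dice_loss` above, so the two versions stay comparable.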
5.
".. ". *.nrrd .
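import nrrd
import numpy as np

data_path = 'some_path.nrrd'  # placeholder path
# nrrd.read returns the voxel data as a numpy array together with the header
read = nrrd.read(data_path)
data, meta_data = read[0], read[1]
print(data.shape, np.max(data), np.min(data), meta_data, sep="\n")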
(163, 112, 120)
14982
-2254
OrderedDict([('type', 'short'), ('dimension', 3), ('space', 'left-posterior-superior'), ('sizes', array([163, 112, 120])), ('space directions', array([[-0.5, 0. , 0. ],
[ 0. , -0.5, 0. ],
[ 0. , 0. , 0.5]])), ('kinds', ['domain', 'domain', 'domain']), ('endian', 'little'), ('encoding', 'gzip'), ('space origin', array([131.57200623, 80.7661972 , 32.29940033]))])
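The header printed above already contains the voxel spacing. As a small worked example, the values below are copied by hand from that output (so the snippet runs without the file); the length of each row of `space directions` gives the spacing along the matching axis.

```python
import numpy as np

# Values copied from the header printed above.
space_directions = np.array([[-0.5, 0.0, 0.0],
                             [0.0, -0.5, 0.0],
                             [0.0, 0.0, 0.5]])
sizes = np.array([163, 112, 120])

# Voxel spacing along each axis is the length of the matching direction vector.
spacing_mm = np.linalg.norm(space_directions, axis=1)
extent_mm = sizes * spacing_mm

print(spacing_mm.tolist())  # [0.5, 0.5, 0.5]
print(extent_mm.tolist())   # [81.5, 56.0, 60.0]
```

The signs in `space directions` only encode orientation in the left-posterior-superior space; the 0.5 mm spacing agrees with the 1 px = 0.5 mm convention used throughout the article.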
- , ? , , , .

, 8 12 . ( ) - ( 3- ) . , , "" -1 , ..

- , . , . Skimage Stl.
from skimage.measure import marching_cubes
import nrrd
import numpy as np
from stl import mesh

path = 'some_path.nrrd'
data = nrrd.read(path)[0]


def three_d_creator(some_data):
    # marching_cubes returns vertices, faces, normals and values; only the first two are used
    vertices, faces, _, _ = marching_cubes(some_data)
    # One zero-initialised record per triangle, filled in below
    cube = mesh.Mesh(np.zeros(faces.shape[0], dtype=mesh.Mesh.dtype))
    for i, f in enumerate(faces):
        for j in range(3):
            cube.vectors[i][j] = vertices[f[j]]
    cube.save('name.stl')
    return cube


stl = three_d_creator(data)
, "" . , , Win 10 3D Builder - . "" 3D . " " .

v3do. , , .
npy stl
import numpy as np
from vedo import Volume, show, write

prediction = 'some_data_path.npy'


def show_save(data_path, save=False):
    # Load the predicted label volume that was saved as a numpy array
    data = np.load(data_path)
    # alphaUnit and addScalarBar3D are the camelCase names of older vedo releases
    data_multiclass = Volume(data, c='Set2', alpha=(0.1, 1), alphaUnit=0.87, mode=1)
    data_multiclass.addScalarBar3D(nlabels=9)
    show([(data_multiclass, "Multiclass teeth segmentation prediction")], bg='black', N=1, axes=1).close()
    if save:
        write(data_multiclass.isosurface(), 'some_name_.stl')


show_save(prediction, save=True)
.
. :
model.summary()
model = UNet(dim=3, in_channels=1, out_channels=9, n_blocks=3, start_filters=9).to(device)
print(summary(model, (1, 168*, 120, 120)))
"""
----------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================
Conv3d-1 [-1, 9, 168, 120, 120] 252
ReLU-2 [-1, 9, 168, 120, 120] 0
BatchNorm3d-3 [-1, 9, 168, 120, 120] 18
Conv3d-4 [-1, 9, 168, 120, 120] 2,196
ReLU-5 [-1, 9, 168, 120, 120] 0
BatchNorm3d-6 [-1, 9, 168, 120, 120] 18
MaxPool3d-7 [-1, 9, 84, 60, 60] 0
DownBlock-8 [[-1, 9, 84, 60, 60], [-1, 9, 168, 120, 120]] 0
Conv3d-9 [-1, 18, 84, 60, 60] 4,392
ReLU-10 [-1, 18, 84, 60, 60] 0
BatchNorm3d-11 [-1, 18, 84, 60, 60] 36
Conv3d-12 [-1, 18, 84, 60, 60] 8,766
ReLU-13 [-1, 18, 84, 60, 60] 0
BatchNorm3d-14 [-1, 18, 84, 60, 60] 36
MaxPool3d-15 [-1, 18, 42, 30, 30] 0
DownBlock-16 [[-1, 18, 18, 42, 30], [-1, 18, 84, 60, 60]] 0
Conv3d-17 [-1, 36, 42, 30, 30] 17,532
ReLU-18 [-1, 36, 42, 30, 30] 0
BatchNorm3d-19 [-1, 36, 42, 30, 30] 72
Conv3d-20 [-1, 36, 42, 30, 30] 35,028
ReLU-21 [-1, 36, 42, 30, 30] 0
BatchNorm3d-22 [-1, 36, 42, 30, 30] 72
DownBlock-23 [[-1, 36, 42, 30, 30], [-1, 36, 42, 30, 30]] 0
ConvTranspose3d-24 [-1, 18, 84, 60, 60] 5,202
ReLU-25 [-1, 18, 84, 60, 60] 0
BatchNorm3d-26 [-1, 18, 84, 60, 60] 36
Concatenate-27 [-1, 36, 84, 60, 60] 0
Conv3d-28 [-1, 18, 84, 60, 60] 17,514
ReLU-29 [-1, 18, 84, 60, 60] 0
BatchNorm3d-30 [-1, 18, 84, 60, 60] 36
Conv3d-31 [-1, 18, 84, 60, 60] 8,766
ReLU-32 [-1, 18, 84, 60, 60] 0
BatchNorm3d-33 [-1, 18, 84, 60, 60] 36
UpBlock-34 [-1, 18, 84, 60, 60] 0
ConvTranspose3d-35 [-1, 9, 168, 120, 120] 1,305
ReLU-36 [-1, 9, 168, 120, 120] 0
BatchNorm3d-37 [-1, 9, 168, 120, 120] 18
Concatenate-38 [-1, 18, 168, 120, 120] 0
Conv3d-39 [-1, 9, 168, 120, 120] 4,383
ReLU-40 [-1, 9, 168, 120, 120] 0
BatchNorm3d-41 [-1, 9, 168, 120, 120] 18
Conv3d-42 [-1, 9, 168, 120, 120] 2,196
ReLU-43 [-1, 9, 168, 120, 120] 0
BatchNorm3d-44 [-1, 9, 168, 120, 120] 18
UpBlock-45 [-1, 9, 168, 120, 120] 0
Conv3d-46 [-1, 9, 168, 120, 120] 90
================================================================
Total params: 108,036
Trainable params: 108,036
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 3.96
Forward/backward pass size (MB): 12170.30
Params size (MB): 0.41
Estimated Total Size (MB): 12174.66
----------------------------------------------------------------
"""
* ([9, 18, 36, 72]), - 9*(168, 120, 120)
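A back-of-the-envelope check shows why full-resolution feature maps dominate the memory footprint of this model. The sketch assumes float32 activations (4 bytes per value) and looks at a single feature map only; the helper name `feature_map_mib` is made up for this example.

```python
BYTES_PER_FLOAT32 = 4
MIB = 1024 ** 2


def feature_map_mib(channels: int, depth: int, height: int, width: int) -> float:
    """Memory of one float32 feature map in MiB."""
    return channels * depth * height * width * BYTES_PER_FLOAT32 / MIB


one_channel = feature_map_mib(1, 168, 120, 120)
nine_channels = feature_map_mib(9, 168, 120, 120)

print(round(one_channel, 2))    # 9.23
print(round(nine_channels, 2))  # 83.06
```

The summary lists many 9-channel maps at the full 168 * 120 * 120 resolution, and each of them has to be kept for the backward pass, so the total grows quickly even though the model has only about 108 thousand parameters.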

, , . ? - "" 8- , . , 12 (GPU) .

6. Afterword
, , - . . , , 2 , . , ? , , 28 , , "" / ? U-net GCNN Pytorch - Pytorch3D? , , bounding box( 1 ). , , .
()

" "

Special thanks to my wife, Alena, for her exceptional support during this "dive into the darkness".
Thank you all for your attention. Constructive criticism and suggestions are welcome, whether corrections or ideas for new projects.