3D tooth instance segmentation. In the dark, but not alone

3D tooth segmentation, from getting the data to the final result. Almost.





Disclaimer

This article is not educational in any sense of the word and is purely informational. The author is not responsible for the time you spend reading it.





About the author

Hello everyone, my name is Andrey (27). I'll try to be brief. Why programming? By education I have a bachelor's degree in electromechanics, and I know the trade. I worked successfully for 2 years as a power engineer at a drilling company. Instead of taking a promotion, I handed in my resignation: I had burned out, and the work turned out not to be for me. I like to create and to find solutions to complex problems, and I have been inseparable from a PC for as long as I can remember. The choice was obvious. At first (six months ago) I seriously considered signing up for courses of one kind or another. I read the reviews, talked to participants, and realized that finding the information was not the problem. So I found a site, picked up the basics of Python there, and started my journey from that (I am now gradually studying everything related to ML). I was immediately drawn to machine learning, and to computer vision (CV) in particular. I came up with a problem for myself, and here I am (for me, this is an excellent way to learn).





1. Introduction

After several unsuccessful attempts, I decided to use 2 lightweight models to get the desired result. The first segments all teeth as a single category [1, 0], and the second splits them into categories [0, 8]. But let's take things in order.





2. Finding and preparing the data

After spending more than one evening searching for data to work with, I came to the conclusion that a free jaw scan of good quality and in a usable format (*.stl, *.nrrd, etc.) was not going to turn up. The best I came across was a sample dataset in 3D Slicer: the head of a patient after jaw surgery.





Obviously I don't need the whole head, so I cropped the source in the same program to 163 * 112 * 120 px (in this article {x * y * z = w * d * h} and 1 px = 0.5 mm), leaving only the teeth and the adjacent maxillofacial parts.
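
As a quick sanity check on the crop, the physical extent follows directly from the voxel counts and the 0.5 mm spacing quoted above. This is only an arithmetic sketch; the variable names are made up for the illustration.

```python
VOXEL_MM = 0.5                 # 1 px = 0.5 mm, as stated above
crop_px = (163, 112, 120)      # x, y, z size of the cropped volume

# Physical extent of the crop along each axis, in millimetres.
crop_mm = tuple(n * VOXEL_MM for n in crop_px)

# Total number of voxels the models have to label.
voxel_count = crop_px[0] * crop_px[1] * crop_px[2]

print(crop_mm)
print(voxel_count)
```

So the crop covers roughly 8 x 5.6 x 6 cm, which is about 2.2 million voxels.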





, - . . , - "autothreshold" , , , , ( ).





- ( )? -

12~14. , 4 . , .





. Smooth 0.5. ( )

, ( ) , . , , N- , random-crop .





import nrrd
import torch
import torchvision.transforms as tf


class DataBuilder:
    def __init__(self,
                 data_path,
                 list_of_categories,
                 num_of_chunks: int = 0,
                 augmentation_coeff: int = 0,
                 num_of_classes: int = 0,
                 normalise: bool = False,
                 fit: bool = True,
                 data_format: int = 0,
                 save_data: bool = False
                 ):
        self.data_path = data_path
        self.number_of_chunks = num_of_chunks
        self.augmentation_coeff = augmentation_coeff
        self.list_of_cats = list_of_categories
        self.num_of_cls = num_of_classes
        self.normalise = normalise
        self.fit = fit
        self.data_format = data_format
        self.save_data = save_data

    def forward(self):
        data = self.get_data()
        data = self.fit_data(data) if self.fit else data
        data = self.pre_normalize(data) if self.normalise else data
        data = self.data_augmentation(data, self.augmentation_coeff) if self.augmentation_coeff != 0 else data
        data = self.new_chunks(data, self.number_of_chunks) if self.number_of_chunks != 0 else data
        data = self.category_splitter(data, self.num_of_cls, self.list_of_cats) if self.num_of_cls != 0 else data
        if self.save_data:
            torch.save(data, self.data_path[-14:] + '.pt')

        return torch.unsqueeze(data, 1)

    def get_data(self):
        if self.data_format == 0:
            return torch.from_numpy(nrrd.read(self.data_path)[0])
        elif self.data_format == 1:
            return torch.load(self.data_path).cpu()
        elif self.data_format == 2:
            return torch.unsqueeze(self.data_path, 0).cpu()
        else:
            # Fail loudly instead of silently returning None
            raise ValueError('Available types are: 0 ("nrrd"), 1 ("tensor") or 2 ("self.tensor(w/o load)")')

    @staticmethod
    def fit_data(some_data):
        data = torch.movedim(some_data, (1, 0), (0, -1))
        data_add_x = torch.nn.ZeroPad2d((5, 0, 0, 0))
        data = data_add_x(data)
        data = torch.movedim(data, -1, 0)
        data_add_z = torch.nn.ZeroPad2d((0, 0, 8, 0))

        return data_add_z(data)

    @staticmethod
    def pre_normalize(some_data):
        min_d, max_d = torch.min(some_data), torch.max(some_data)

        return (some_data - min_d) / (max_d - min_d)

    @staticmethod
    def data_augmentation(some_data, aug_n):
        torch.manual_seed(17)
        tr_data = []
        for e in range(aug_n):
            transform = tf.RandomRotation(degrees=(20*e, 20*e))
            for image in some_data:
                image = torch.unsqueeze(image, 0)
                image = transform(image)
                tr_data.append(image)

        return tr_data

    def new_chunks(self, some_data, n_ch):
        data = torch.stack(some_data, 0) if self.augmentation_coeff != 0 else some_data
        data = torch.squeeze(data, 1)
        chunks = torch.chunk(data, n_ch, 0)

        return torch.stack(chunks)

    @staticmethod
    def category_splitter(some_data, alpha, list_of_categories):
        data, offset = torch.squeeze(some_data, 1).to(torch.int64), alpha
        for i in list_of_categories:
            data = torch.where(data < i, offset, data)
            offset += 1

        return data - alpha
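
The two simplest steps of `DataBuilder`, zero padding (`fit_data`) and min-max scaling (`pre_normalize`), can be checked by hand. The sketch below is a dependency-free illustration and not the class itself. The helper names are hypothetical, the intensity bounds are the minimum and maximum printed for the sample volume later in the article, and the guard against constant input is an addition that the original `pre_normalize` does not have.

```python
def left_pad_amounts(shape, target):
    """Zeros to prepend on each axis so that `shape` grows to `target`."""
    if any(t < s for s, t in zip(shape, target)):
        raise ValueError("target must be at least as large as shape")
    return tuple(t - s for s, t in zip(shape, target))


def min_max(values):
    """Rescale a flat list of intensities linearly to the range 0..1."""
    lo, hi = min(values), max(values)
    if hi == lo:
        raise ValueError("constant input cannot be rescaled")
    return [(v - lo) / (hi - lo) for v in values]


# 163 -> 168 along x and 112 -> 120 along y, matching the two ZeroPad2d calls.
print(left_pad_amounts((163, 112, 120), (168, 120, 120)))

# The extreme intensities of the sample volume map to 0 and 1.
print(min_max([-2254, 0, 14982]))
```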

      
      



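Preparing the data for the 3D U-net. The class above does the following:

  • reads the data (`get_data`);
  • pads it with zeros to 168*120*120 (from 163*112*120);
  • normalises the intensities to 0...1 (from ~-2000...16000);
  • augments the data N times;
  • splits it into chunks (final shape 1, 1, 72, 120, 120);
  • regroups the 28 categories:
    • into 1 for the 1st model;
    • into 9 (8 + background) for the 2nd model.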





Dataloader
import torch.utils.data as tud


class ToothDataset(tud.Dataset):
    def __init__(self, images, masks):
        self.images = images
        self.masks = masks

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        if self.masks is not None:
            return self.images[index, :, :, :, :],\
                    self.masks[index, :, :, :, :]
        else:
            return self.images[index, :, :, :, :]


def get_loaders(images, masks,
                batch_size: int = 1,
                num_workers: int = 1,
                pin_memory: bool = True):

    train_ds = ToothDataset(images=images,
                            masks=masks)

    data_loader = tud.DataLoader(train_ds,
                                 batch_size=batch_size,
                                 shuffle=False,
                                 num_workers=num_workers,
                                 pin_memory=pin_memory)

    return data_loader

      
      



:









         Semantic                          Instance                          Predictions
Data     (27*, 1, 56*, 120, 120) [0...1]   (27*, 1, 56*, 120, 120) [0, 1]    (1, 1, 168, 120, 120) [0...1]
Masks    (27*, 1, 56*, 120, 120) [0, 1]    (27*, 1, 56*, 120, 120) [0, 8]    -





* , , - .





3.

- . U-Net. , .





2D U-Net

, . - Adam, Dice-loss(implement), / 4, [64, 128, 256, 512] (, , - ). 60-80 epochs . Transfer learning .





model.summary()
model = UNet(dim=2, in_channels=1, out_channels=1, n_blocks=4, start_filters=64).to(device)
print(summary(model, (1, 168, 120)))

"""
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 64, 168, 120]             640
              ReLU-2         [-1, 64, 168, 120]               0
       BatchNorm2d-3         [-1, 64, 168, 120]             128
            Conv2d-4         [-1, 64, 168, 120]          36,928
              ReLU-5         [-1, 64, 168, 120]               0
       BatchNorm2d-6         [-1, 64, 168, 120]             128
         MaxPool2d-7           [-1, 64, 84, 60]               0
         DownBlock-8  [[-1, 64, 84, 60], [-1, 64, 168, 120]]  0
            Conv2d-9          [-1, 128, 84, 60]          73,856
             ReLU-10          [-1, 128, 84, 60]               0
      BatchNorm2d-11          [-1, 128, 84, 60]             256
           Conv2d-12          [-1, 128, 84, 60]         147,584
             ReLU-13          [-1, 128, 84, 60]               0
      BatchNorm2d-14          [-1, 128, 84, 60]             256
        MaxPool2d-15          [-1, 128, 42, 30]               0
        DownBlock-16  [[-1, 128, 42, 30], [-1, 128, 84, 60]]  0
           Conv2d-17          [-1, 256, 42, 30]         295,168
             ReLU-18          [-1, 256, 42, 30]               0
      BatchNorm2d-19          [-1, 256, 42, 30]             512
           Conv2d-20          [-1, 256, 42, 30]         590,080
             ReLU-21          [-1, 256, 42, 30]               0
      BatchNorm2d-22          [-1, 256, 42, 30]             512
        MaxPool2d-23          [-1, 256, 21, 15]               0
        DownBlock-24  [[-1, 256, 21, 15], [-1, 256, 42, 30]]  0
           Conv2d-25          [-1, 512, 21, 15]       1,180,160
             ReLU-26          [-1, 512, 21, 15]               0
      BatchNorm2d-27          [-1, 512, 21, 15]           1,024
           Conv2d-28          [-1, 512, 21, 15]       2,359,808
             ReLU-29          [-1, 512, 21, 15]               0
      BatchNorm2d-30          [-1, 512, 21, 15]           1,024
        DownBlock-31  [[-1, 512, 21, 15], [-1, 512, 21, 15]]  0
  ConvTranspose2d-32          [-1, 256, 42, 30]         524,544
             ReLU-33          [-1, 256, 42, 30]               0
      BatchNorm2d-34          [-1, 256, 42, 30]             512
      Concatenate-35          [-1, 512, 42, 30]               0
           Conv2d-36          [-1, 256, 42, 30]       1,179,904
             ReLU-37          [-1, 256, 42, 30]               0
      BatchNorm2d-38          [-1, 256, 42, 30]             512
           Conv2d-39          [-1, 256, 42, 30]         590,080
             ReLU-40          [-1, 256, 42, 30]               0
      BatchNorm2d-41          [-1, 256, 42, 30]             512
          UpBlock-42          [-1, 256, 42, 30]               0
  ConvTranspose2d-43          [-1, 128, 84, 60]         131,200
             ReLU-44          [-1, 128, 84, 60]               0
      BatchNorm2d-45          [-1, 128, 84, 60]             256
      Concatenate-46          [-1, 256, 84, 60]               0
           Conv2d-47          [-1, 128, 84, 60]         295,040
             ReLU-48          [-1, 128, 84, 60]               0
      BatchNorm2d-49          [-1, 128, 84, 60]             256
           Conv2d-50          [-1, 128, 84, 60]         147,584
             ReLU-51          [-1, 128, 84, 60]               0
      BatchNorm2d-52          [-1, 128, 84, 60]             256
          UpBlock-53          [-1, 128, 84, 60]               0
  ConvTranspose2d-54         [-1, 64, 168, 120]          32,832
             ReLU-55         [-1, 64, 168, 120]               0
      BatchNorm2d-56         [-1, 64, 168, 120]             128
      Concatenate-57        [-1, 128, 168, 120]               0
           Conv2d-58         [-1, 64, 168, 120]          73,792
             ReLU-59         [-1, 64, 168, 120]               0
      BatchNorm2d-60         [-1, 64, 168, 120]             128
           Conv2d-61         [-1, 64, 168, 120]          36,928
             ReLU-62         [-1, 64, 168, 120]               0
      BatchNorm2d-63         [-1, 64, 168, 120]             128
          UpBlock-64         [-1, 64, 168, 120]               0
           Conv2d-65          [-1, 1, 168, 120]              65
================================================================
Total params: 7,702,721
Trainable params: 7,702,721
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.08
Forward/backward pass size (MB): 7434.08
Params size (MB): 29.38
Estimated Total Size (MB): 7463.54
"""
      
      
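
The "Params size (MB)" line of the summary above can be cross-checked by hand. This is a minimal sketch, assuming every parameter is stored as a 4-byte float32 and that "MB" means 1024 * 1024 bytes. Both are assumptions about the summary tool, although they do reproduce the reported figure.

```python
BYTES_PER_FLOAT32 = 4


def params_size_mb(num_params):
    """Memory taken by float32 weights, in units of 1024 * 1024 bytes."""
    return num_params * BYTES_PER_FLOAT32 / 1024 ** 2


# Total parameter count reported for the 2D U-Net above.
print(round(params_size_mb(7_702_721), 2))
```

The weights are therefore a small part of the estimated total. Almost all of the memory goes to the activations kept for the backward pass.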



Exp. No. 1: 2D U-Net, [x, z]

, - . , . numpy - *.stl 6. , :





  :
: 1. [x, y]. 2. [x, z]. 3. [y, z]

100% , ? , .





, , , , , .





Exp. No. 2: 2nd 2D U-Net, [y, z]

, , :





Exp. No. 3: 2nd 2D U-Net, [y, z], 50%

3D . , (24*, 120, 120). ? - (~22. ). (1063gtx) .





24*

. :





  • (1512, 120, 120) - 63;





  • batch size (24, 120, 120) - , ;





  • (24) / ( 24/2/2/2=3 3*2*2*2=24, / 2 / 1);





  • , . .summary()
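
The depth arithmetic in the list above (24/2/2/2 = 3 and 3*2*2*2 = 24) amounts to requiring that the chunk depth survives every 2x pooling without losing a slice, so that the decoder can restore the original size exactly. A small dependency-free check, with a hypothetical helper name:

```python
def depth_trace(depth, n_poolings):
    """Depth after each 2x pooling; raises if a halving would lose a slice."""
    trace = [depth]
    for _ in range(n_poolings):
        if trace[-1] % 2:
            raise ValueError(f"depth {trace[-1]} cannot be halved evenly")
        trace.append(trace[-1] // 2)
    return trace


# Four blocks mean three poolings: 24 -> 12 -> 6 -> 3.
print(depth_trace(24, 3))
```

A depth of 20, for example, would fail at the third pooling (20 -> 10 -> 5), and the upsampled tensor would no longer match its skip connection.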





model.summary()
model = UNet(dim=3, in_channels=1, out_channels=1, n_blocks=4, start_filters=64).to(device)
print(summary(model, (1, 24, 120, 120)))

"""
  ----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv3d-1     [-1, 64, 24, 120, 120]             1,792
              ReLU-2     [-1, 64, 24, 120, 120]                 0
       BatchNorm3d-3     [-1, 64, 24, 120, 120]               128
            Conv3d-4     [-1, 64, 24, 120, 120]           110,656
              ReLU-5     [-1, 64, 24, 120, 120]                 0
       BatchNorm3d-6     [-1, 64, 24, 120, 120]               128
         MaxPool3d-7        [-1, 64, 12, 60, 60]                0
         DownBlock-8  [[-1, 64, 12, 60, 60], [-1, 64, 24, 120, 120]]               0
            Conv3d-9       [-1, 128, 12, 60, 60]          221,312
             ReLU-10       [-1, 128, 12, 60, 60]                0
      BatchNorm3d-11       [-1, 128, 12, 60, 60]              256
           Conv3d-12       [-1, 128, 12, 60, 60]          442,496
             ReLU-13       [-1, 128, 12, 60, 60]                0
      BatchNorm3d-14       [-1, 128, 12, 60, 60]              256
        MaxPool3d-15       [-1, 128, 6, 30, 30]                 0
        DownBlock-16  [[-1, 128, 6, 30, 30], [-1, 128, 12, 60, 60]]               0
           Conv3d-17       [-1, 256, 6, 30, 30]           884,992
             ReLU-18       [-1, 256, 6, 30, 30]                 0
      BatchNorm3d-19       [-1, 256, 6, 30, 30]               512
           Conv3d-20       [-1, 256, 6, 30, 30]         1,769,728
             ReLU-21       [-1, 256, 6, 30, 30]                 0
      BatchNorm3d-22       [-1, 256, 6, 30, 30]               512
        MaxPool3d-23       [-1, 256, 3, 15, 15]                 0
        DownBlock-24  [[-1, 256, 3, 15, 15], [-1, 256, 6, 30, 30]]               0
           Conv3d-25       [-1, 512, 3, 15, 15]         3,539,456
             ReLU-26       [-1, 512, 3, 15, 15]                 0
      BatchNorm3d-27       [-1, 512, 3, 15, 15]             1,024
           Conv3d-28       [-1, 512, 3, 15, 15]         7,078,400
             ReLU-29       [-1, 512, 3, 15, 15]                 0
      BatchNorm3d-30       [-1, 512, 3, 15, 15]             1,024
        DownBlock-31  [[-1, 512, 3, 15, 15], [-1, 512, 3, 15, 15]]               0
  ConvTranspose3d-32       [-1, 256, 6, 30, 30]         1,048,832
             ReLU-33       [-1, 256, 6, 30, 30]                 0
      BatchNorm3d-34       [-1, 256, 6, 30, 30]               512
      Concatenate-35       [-1, 512, 6, 30, 30]                 0
           Conv3d-36       [-1, 256, 6, 30, 30]         3,539,200
             ReLU-37       [-1, 256, 6, 30, 30]                 0
      BatchNorm3d-38       [-1, 256, 6, 30, 30]               512
           Conv3d-39       [-1, 256, 6, 30, 30]         1,769,728
             ReLU-40       [-1, 256, 6, 30, 30]                 0
      BatchNorm3d-41       [-1, 256, 6, 30, 30]               512
          UpBlock-42       [-1, 256, 6, 30, 30]                 0
  ConvTranspose3d-43       [-1, 128, 12, 60, 60]          262,272
             ReLU-44       [-1, 128, 12, 60, 60]                0
      BatchNorm3d-45       [-1, 128, 12, 60, 60]              256
      Concatenate-46       [-1, 256, 12, 60, 60]                0
           Conv3d-47       [-1, 128, 12, 60, 60]          884,864
             ReLU-48       [-1, 128, 12, 60, 60]                0
      BatchNorm3d-49       [-1, 128, 12, 60, 60]              256
           Conv3d-50       [-1, 128, 12, 60, 60]          442,496
             ReLU-51       [-1, 128, 12, 60, 60]                0
      BatchNorm3d-52       [-1, 128, 12, 60, 60]              256
          UpBlock-53       [-1, 128, 12, 60, 60]                0
  ConvTranspose3d-54       [-1, 64, 24, 120, 120]          65,600
             ReLU-55       [-1, 64, 24, 120, 120]               0
      BatchNorm3d-56       [-1, 64, 24, 120, 120]             128
      Concatenate-57      [-1, 128, 24, 120, 120]               0
           Conv3d-58       [-1, 64, 24, 120, 120]         221,248
             ReLU-59       [-1, 64, 24, 120, 120]               0
      BatchNorm3d-60       [-1, 64, 24, 120, 120]             128
           Conv3d-61       [-1, 64, 24, 120, 120]         110,656
             ReLU-62       [-1, 64, 24, 120, 120]               0
      BatchNorm3d-63       [-1, 64, 24, 120, 120]             128
          UpBlock-64       [-1, 64, 24, 120, 120]               0
           Conv3d-65        [-1, 1, 24, 120, 120]              65
================================================================
Total params: 22,400,321
Trainable params: 22,400,321
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.61
Forward/backward pass size (MB): 15974.12
Params size (MB): 85.45
Estimated Total Size (MB): 16060.18
----------------------------------------------------------------
"""
      
      



Exp. No. 4: 3D U-Net, [y, z], *0.38

~60% (25 epochs) , .





Exp. No. 5: 3D U-Net, [y, z], 65 epochs ~ 1.5

. , (.№3) - :





Exp. No. 6: 3D U-Net, [x, z], 105 epochs ~ 2.1

"" . ~400 ( ~22) [18, 32, 64, 128] / 3. RSMProp. (1, 1, 72*, 120, 120). ?





model.summary()
model = UNet(dim=3, in_channels=1, out_channels=1, n_blocks=3, start_filters=18).to(device)
print(summary(model, (1, 72, 120, 120)))  # input size without the batch dimension

"""
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv3d-1     [-1, 18, 72, 120, 120]             504
              ReLU-2     [-1, 18, 72, 120, 120]               0
       BatchNorm3d-3     [-1, 18, 72, 120, 120]              36
            Conv3d-4     [-1, 18, 72, 120, 120]           8,766
              ReLU-5     [-1, 18, 72, 120, 120]               0
       BatchNorm3d-6     [-1, 18, 72, 120, 120]              36
         MaxPool3d-7       [-1, 18, 36, 60, 60]               0
         DownBlock-8  [[-1, 18, 36, 60, 60], [-1, 18, 72, 120, 120]]               0
            Conv3d-9       [-1, 36, 36, 60, 60]          17,532
             ReLU-10       [-1, 36, 36, 60, 60]               0
      BatchNorm3d-11       [-1, 36, 36, 60, 60]              72
           Conv3d-12       [-1, 36, 36, 60, 60]          35,028
             ReLU-13       [-1, 36, 36, 60, 60]               0
      BatchNorm3d-14       [-1, 36, 36, 60, 60]              72
        MaxPool3d-15        [-1, 36, 18, 30, 30]              0
        DownBlock-16  [[-1, 36, 18, 30, 30], [-1, 36, 36, 60, 60]]               0
           Conv3d-17        [-1, 72, 18, 30, 30]         70,056
             ReLU-18        [-1, 72, 18, 30, 30]              0
      BatchNorm3d-19        [-1, 72, 18, 30, 30]            144
           Conv3d-20        [-1, 72, 18, 30, 30]        140,040
             ReLU-21        [-1, 72, 18, 30, 30]              0
      BatchNorm3d-22        [-1, 72, 18, 30, 30]            144
        DownBlock-23  [[-1, 72, 18, 30, 30], [-1, 72, 18, 30, 30]]               0
  ConvTranspose3d-24       [-1, 36, 36, 60, 60]          20,772
             ReLU-25       [-1, 36, 36, 60, 60]               0
      BatchNorm3d-26       [-1, 36, 36, 60, 60]              72
      Concatenate-27       [-1, 72, 36, 60, 60]               0
           Conv3d-28       [-1, 36, 36, 60, 60]          70,020
             ReLU-29       [-1, 36, 36, 60, 60]               0
      BatchNorm3d-30       [-1, 36, 36, 60, 60]              72
           Conv3d-31       [-1, 36, 36, 60, 60]          35,028
             ReLU-32       [-1, 36, 36, 60, 60]               0
      BatchNorm3d-33       [-1, 36, 36, 60, 60]              72
          UpBlock-34       [-1, 36, 36, 60, 60]               0
  ConvTranspose3d-35     [-1, 18, 72, 120, 120]           5,202
             ReLU-36     [-1, 18, 72, 120, 120]               0
      BatchNorm3d-37     [-1, 18, 72, 120, 120]              36
      Concatenate-38     [-1, 36, 72, 120, 120]               0
           Conv3d-39     [-1, 18, 72, 120, 120]          17,514
             ReLU-40     [-1, 18, 72, 120, 120]               0
      BatchNorm3d-41     [-1, 18, 72, 120, 120]              36
           Conv3d-42     [-1, 18, 72, 120, 120]           8,766
             ReLU-43     [-1, 18, 72, 120, 120]               0
      BatchNorm3d-44     [-1, 18, 72, 120, 120]              36
          UpBlock-45     [-1, 18, 72, 120, 120]               0
           Conv3d-46      [-1, 1, 72, 120, 120]              19
================================================================
Total params: 430,075
Trainable params: 430,075
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 1.32
Forward/backward pass size (MB): 5744.38
Params size (MB): 1.64
Estimated Total Size (MB): 5747.34
----------------------------------------------------------------
"""
      
      
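
The parameter total in the summary above can be reproduced from the layer list alone. The sketch below assumes the block layout visible in the summaries (two 3-wide convolutions with batch norm per encoder block, a 2-wide transposed convolution followed by two 3-wide convolutions per decoder block, and a final 1-wide convolution). It is an illustration of the counting, not the `UNet` class used in the article, and the function names are hypothetical.

```python
def conv_params(c_in, c_out, kernel, dim):
    """Weights plus biases of a convolution (also valid for a transposed one)."""
    return c_in * c_out * kernel ** dim + c_out


def batchnorm_params(channels):
    """One scale and one shift per channel."""
    return 2 * channels


def unet_params(dim, in_channels, out_channels, n_blocks, start_filters):
    total, c_in, c_out = 0, in_channels, start_filters
    # Encoder: every block is conv -> BN -> conv -> BN with 3-wide kernels.
    for _ in range(n_blocks):
        total += conv_params(c_in, c_out, 3, dim) + batchnorm_params(c_out)
        total += conv_params(c_out, c_out, 3, dim) + batchnorm_params(c_out)
        c_in, c_out = c_out, c_out * 2
    # Decoder: 2-wide transposed conv, then two convs; the first conv sees
    # the upsampled features concatenated with the skip connection.
    for _ in range(n_blocks - 1):
        c_out = c_in // 2
        total += conv_params(c_in, c_out, 2, dim) + batchnorm_params(c_out)
        total += conv_params(2 * c_out, c_out, 3, dim) + batchnorm_params(c_out)
        total += conv_params(c_out, c_out, 3, dim) + batchnorm_params(c_out)
        c_in = c_out
    # Final 1-wide convolution maps to the output channels.
    return total + conv_params(c_in, out_channels, 1, dim)


print(unet_params(dim=3, in_channels=1, out_channels=1, n_blocks=3, start_filters=18))
```

The same function, called with `n_blocks=4` and `start_filters=64`, gives the totals of the two earlier summaries, which makes the effect of fewer blocks and narrower layers easy to see.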



72*

, (168, 120, 120), (72, 120, 120). , . , 2 , . 9 (1512, 120, 120) .. 9 , 21(batch size) (72, 120, 120). 72 , 24*().
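
The chunk counts quoted in this article are plain shape arithmetic. The sketch below assumes that the 1512 slices are nine 168-slice volumes stacked along the depth axis (1512 = 9 * 168). That reading is an assumption; only the shapes themselves come from the text. The helper is hypothetical and stricter than `torch.chunk`, which also accepts uneven splits.

```python
def chunk_count(total_depth, chunk_depth):
    """Number of equal chunks along the depth axis; depth must divide evenly."""
    count, remainder = divmod(total_depth, chunk_depth)
    if remainder:
        raise ValueError("chunk depth does not divide the total depth")
    return count


stacked_depth = 9 * 168                  # assumed: nine 168-slice volumes
print(stacked_depth)
print(chunk_count(stacked_depth, 72))    # chunks of (72, 120, 120)
print(chunk_count(stacked_depth, 24))    # chunks of (24, 120, 120)
```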





Exp. No. 7: 3D U-Net, [x, z], (65 epochs) ~ 14.

, ( "" ). , . semantic segmentation , .





3D ( ) (1512, 120, 120) --> 21*(1, 72, 120, 120), ~*(30, 30, 30) ( ). 2 : 3- , ( ); , .





, 1 epochs "" ~13, 2 (>80). 1 epochs. , .





. 8 + . loss function .





training loop
import torch
from tqdm import tqdm
from _loss_f import LossFunction


class TrainFunction:
    def __init__(self,
                 data_loader,
                 device_for_training,
                 model_name,
                 model_name_pretrained,
                 model,
                 optimizer,
                 scale,
                 learning_rate: float = 1e-2,
                 num_epochs: int = 1,
                 transfer_learning: bool = False,
                 binary_loss_f: bool = True
                 ):
        self.data_loader = data_loader
        self.device = device_for_training
        self.model_name_pretrained = model_name_pretrained
        self.semantic_binary = binary_loss_f
        self.num_epochs = num_epochs
        self.model_name = model_name
        self.transfer = transfer_learning
        self.optimizer = optimizer
        self.learning_rate = learning_rate
        self.model = model
        self.scale = scale

    def forward(self):
        print('Running on the:', torch.cuda.get_device_name(self.device))
        if self.transfer:
            self.model.load_state_dict(torch.load(self.model_name_pretrained))
        optimizer = self.optimizer(self.model.parameters(), lr=self.learning_rate)
        for epoch in range(self.num_epochs):
            self.train_loop(self.data_loader, self.model, optimizer, self.scale, epoch)
            # Save a checkpoint every 10 epochs
            if (epoch + 1) % 10 == 0:
                torch.save(self.model.state_dict(),
                           'models/' + self.model_name + str(epoch + 1) + '_epoch.pth')

    def train_loop(self, loader, model, optimizer, scales, i):
        loop, epoch_loss = tqdm(loader), 0
        loop.set_description('Epoch %i' % (self.num_epochs - i))
        for batch_idx, (data, targets) in enumerate(loop):
            data, targets = data.to(device=self.device, dtype=torch.float), \
                            targets.to(device=self.device, dtype=torch.long)
            optimizer.zero_grad()
            # forward pass and loss under mixed precision (autocast)
            with torch.cuda.amp.autocast():
                predictions = model(data)
                loss = LossFunction(predictions, targets,
                                    device_for_training=self.device,
                                    semantic_binary=self.semantic_binary
                                    ).forward()
            scales.scale(loss).backward()
            scales.step(optimizer)
            scales.update()
            epoch_loss += (1 - loss.item())*100
            loop.set_postfix(loss=loss.item())
        print('Epoch-acc', round(epoch_loss / (batch_idx+1), 2))
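
The `Epoch-acc` figure printed by `train_loop` is the mean of (1 - loss) * 100 over the batches of an epoch. With a Dice loss this is the mean Dice score in percent rather than a classification accuracy. A dependency-free sketch of the same computation, with made-up loss values:

```python
def epoch_accuracy(batch_losses):
    """Mean of (1 - loss) * 100 over one epoch, rounded to 2 decimal places."""
    if not batch_losses:
        raise ValueError("need at least one batch")
    scores = [(1 - loss) * 100 for loss in batch_losses]
    return round(sum(scores) / len(scores), 2)


# Three hypothetical batch losses from one epoch.
print(epoch_accuracy([0.40, 0.30, 0.20]))
```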

      
      



4.

Dice-loss , '' [0, 1]. , ( [0, 1]), ( "" "" ) Dice-loss , .





categorical_dice_loss
import torch


class LossFunction:
    def __init__(self,
                 prediction,
                 target,
                 device_for_training,
                 semantic_binary: bool = True,
                 ):
        self.prediction = prediction
        self.device = device_for_training
        self.target = target
        self.semantic_binary = semantic_binary

    def forward(self):
        if self.semantic_binary:
            return self.dice_loss(self.prediction, self.target)
        return self.categorical_dice_loss(self.prediction, self.target)

    @staticmethod
    def dice_loss(predictions, targets, alpha=1e-5):
        intersection = 2. * (predictions * targets).sum()
        denomination = (torch.square(predictions) + torch.square(targets)).sum()
        dice_loss = 1 - torch.mean((intersection + alpha) / (denomination + alpha))

        return dice_loss

    def categorical_dice_loss(self, prediction, target):
        pr, tr = self.prepare_for_multiclass_loss_f(prediction, target)
        target_categories, losses = torch.unique(tr).tolist(), 0
        for num_category in target_categories:
            categorical_target = torch.where(tr == num_category, 1, 0)
            categorical_prediction = pr[num_category]
            losses += self.dice_loss(categorical_prediction, categorical_target).to(self.device)

        return losses / len(target_categories)

    @staticmethod
    def prepare_for_multiclass_loss_f(prediction, target):
        prediction_prepared = torch.squeeze(prediction, 0)
        target_prepared = torch.squeeze(target, 0)
        target_prepared = torch.squeeze(target_prepared, 0)

        return prediction_prepared, target_prepared
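
To see what the two loss functions compute, here is a dependency-free sketch on flat Python lists. It mirrors the formulas of `dice_loss` and `categorical_dice_loss` above (squared terms in the denominator, averaging over the classes present in the target), but it is an illustration on toy data and not a drop-in replacement for the tensor code.

```python
def dice_loss(pred, target, alpha=1e-5):
    """1 - (2 * sum(p * t) + alpha) / (sum(p^2) + sum(t^2) + alpha)."""
    intersection = 2.0 * sum(p * t for p, t in zip(pred, target))
    denominator = sum(p * p for p in pred) + sum(t * t for t in target)
    return 1.0 - (intersection + alpha) / (denominator + alpha)


def categorical_dice_loss(pred_per_class, target_labels):
    """Average the binary Dice loss over the classes present in the target."""
    present = sorted(set(target_labels))
    losses = []
    for cls in present:
        # One-vs-rest mask for this class, compared with its own channel.
        one_vs_rest = [1 if label == cls else 0 for label in target_labels]
        losses.append(dice_loss(pred_per_class[cls], one_vs_rest))
    return sum(losses) / len(losses)


target = [0, 0, 1, 2]
perfect = [[1, 1, 0, 0],   # channel for class 0
           [0, 0, 1, 0],   # channel for class 1
           [0, 0, 0, 1]]   # channel for class 2
print(categorical_dice_loss(perfect, target))   # a perfect prediction
print(dice_loss([1, 0], [0, 1]))                # no overlap at all
```

Because only the classes found in the target are averaged, a chunk that contains two teeth is scored on those two teeth (and the background) only.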

      
      



, "categorical_dice_loss":





  • ( );





  • , batch ;





  • "" "" , [0, 1] Dice-loss;





  • , batct. .





, , one-hot , ( ), , . , , , . (5).





5.

".. ". *.nrrd .





import nrrd
import numpy as np

# nrrd.read returns the volume as a numpy array together with its header
data, meta_data = nrrd.read(data_path)

print(data.shape, np.max(data), np.min(data), meta_data, sep="\n")

(163, 112, 120)
14982
-2254 
 OrderedDict([('type', 'short'), ('dimension', 3), ('space', 'left-posterior-superior'), ('sizes', array([163, 112, 120])), ('space directions', array([[-0.5,  0. ,  0. ],
       [ 0. , -0.5,  0. ],
       [ 0. ,  0. ,  0.5]])), ('kinds', ['domain', 'domain', 'domain']), ('endian', 'little'), ('encoding', 'gzip'), ('space origin', array([131.57200623,  80.7661972 ,  32.29940033]))])
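
The `space directions` and `space origin` fields of the header printed above are what tie a voxel index to a position in the patient's coordinate system. A minimal sketch of that mapping, using the values from the header; the function name is hypothetical, and the formula is the usual NRRD convention (origin plus index-weighted direction vectors).

```python
SPACE_DIRECTIONS = [(-0.5, 0.0, 0.0),
                    (0.0, -0.5, 0.0),
                    (0.0, 0.0, 0.5)]
SPACE_ORIGIN = (131.57200623, 80.7661972, 32.29940033)


def voxel_to_world(index):
    """origin + sum_k index[k] * direction[k], one coordinate per world axis."""
    return tuple(
        SPACE_ORIGIN[axis]
        + sum(index[k] * SPACE_DIRECTIONS[k][axis] for k in range(3))
        for axis in range(3)
    )


print(voxel_to_world((0, 0, 0)))         # the origin itself
print(voxel_to_world((162, 111, 119)))   # the opposite corner of the volume
```

The 0.5 on the diagonal is the 0.5 mm voxel size, and the negative signs only say that the first two array axes run against the left-posterior-superior axes.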
      
      



- , ? , , , .





, 8 12 . ( ) - ( 3- ) . , , "" -1 , ..





It looks as crazy as it sounds

- , . , . Skimage Stl.





from skimage.measure import marching_cubes
import nrrd
import numpy as np
from stl import mesh

path = 'some_path.nrrd'
data = nrrd.read(path)[0]


def three_d_creator(some_data):
    # marching_cubes returns vertices, faces, normals and values; only the first two are used
    vertices, faces, _, _ = marching_cubes(some_data)
    cube = mesh.Mesh(np.zeros(faces.shape[0], dtype=mesh.Mesh.dtype))
    for i, f in enumerate(faces):
        for j in range(3):
            cube.vectors[i][j] = vertices[f[j]]
    cube.save('name.stl')

    return cube


stl = three_d_creator(data)
      
      



, "" . , , Win 10 3D Builder - . "" 3D . " " .





v3do. , , .





npy stl
import numpy as np
from vedo import Volume, show, write

# Load the predicted label volume saved as a numpy array
prediction = np.load('some_data_path.npy')

def show_save(data, save=False):
    data_multiclass = Volume(data, c='Set2', alpha=(0.1, 1), alphaUnit=0.87, mode=1)
    data_multiclass.addScalarBar3D(nlabels=9)
    show([(data_multiclass, "Multiclass teeth segmentation prediction")], bg='black', N=1, axes=1).close()
    if save:
        write(data_multiclass.isosurface(), 'some_name_.stl')
    
show_save(prediction, save=True)
      
      



.





. :





model.summary()
model = UNet(dim=3, in_channels=1, out_channels=9, n_blocks=3, start_filters=9).to(device)
print(summary(model, (1, 168, 120, 120)))  # 168*: see the note after the summary
    
"""
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv3d-1      [-1, 9, 168, 120, 120]            252
              ReLU-2      [-1, 9, 168, 120, 120]              0
       BatchNorm3d-3      [-1, 9, 168, 120, 120]             18
            Conv3d-4      [-1, 9, 168, 120, 120]          2,196
              ReLU-5      [-1, 9, 168, 120, 120]              0
       BatchNorm3d-6      [-1, 9, 168, 120, 120]             18
         MaxPool3d-7        [-1, 9, 84, 60, 60]               0
         DownBlock-8  [[-1, 9, 84, 60, 60], [-1, 9, 168, 120, 120]]               0
            Conv3d-9       [-1, 18, 84, 60, 60]           4,392
             ReLU-10       [-1, 18, 84, 60, 60]               0
      BatchNorm3d-11       [-1, 18, 84, 60, 60]              36
           Conv3d-12       [-1, 18, 84, 60, 60]           8,766
             ReLU-13       [-1, 18, 84, 60, 60]               0
      BatchNorm3d-14       [-1, 18, 84, 60, 60]              36
        MaxPool3d-15       [-1, 18, 42, 30, 30]               0
        DownBlock-16  [[-1, 18, 42, 30, 30], [-1, 18, 84, 60, 60]]               0
           Conv3d-17       [-1, 36, 42, 30, 30]          17,532
             ReLU-18       [-1, 36, 42, 30, 30]               0
      BatchNorm3d-19       [-1, 36, 42, 30, 30]              72
           Conv3d-20       [-1, 36, 42, 30, 30]          35,028
             ReLU-21       [-1, 36, 42, 30, 30]               0
      BatchNorm3d-22       [-1, 36, 42, 30, 30]              72
        DownBlock-23  [[-1, 36, 42, 30, 30], [-1, 36, 42, 30, 30]]               0
  ConvTranspose3d-24       [-1, 18, 84, 60, 60]           5,202
             ReLU-25       [-1, 18, 84, 60, 60]               0
      BatchNorm3d-26       [-1, 18, 84, 60, 60]              36
      Concatenate-27       [-1, 36, 84, 60, 60]               0
           Conv3d-28       [-1, 18, 84, 60, 60]          17,514
             ReLU-29       [-1, 18, 84, 60, 60]               0
      BatchNorm3d-30       [-1, 18, 84, 60, 60]              36
           Conv3d-31       [-1, 18, 84, 60, 60]           8,766
             ReLU-32       [-1, 18, 84, 60, 60]               0
      BatchNorm3d-33       [-1, 18, 84, 60, 60]              36
          UpBlock-34       [-1, 18, 84, 60, 60]               0
  ConvTranspose3d-35      [-1, 9, 168, 120, 120]          1,305
             ReLU-36      [-1, 9, 168, 120, 120]              0
      BatchNorm3d-37      [-1, 9, 168, 120, 120]             18
      Concatenate-38     [-1, 18, 168, 120, 120]              0
           Conv3d-39      [-1, 9, 168, 120, 120]          4,383
             ReLU-40      [-1, 9, 168, 120, 120]              0
      BatchNorm3d-41      [-1, 9, 168, 120, 120]             18
           Conv3d-42      [-1, 9, 168, 120, 120]          2,196
             ReLU-43      [-1, 9, 168, 120, 120]              0
      BatchNorm3d-44      [-1, 9, 168, 120, 120]             18
          UpBlock-45      [-1, 9, 168, 120, 120]              0
           Conv3d-46      [-1, 9, 168, 120, 120]             90
================================================================
Total params: 108,036
Trainable params: 108,036
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 3.96
Forward/backward pass size (MB): 12170.30
Params size (MB): 0.41
Estimated Total Size (MB): 12174.66
----------------------------------------------------------------
    """
      
      



* ([9, 18, 36, 72]), - 9*(168, 120, 120)





Exp. No. 8: Intermediate segmentation into 8 categories

, , . ? - "" 8- , . , 12 (GPU) .





Exp. No. 9: Full segmentation

6. Afterword

, , - . . , , 2 , . , ? , , 28 , , "" / ? U-net GCNN Pytorch - Pytorch3D? , , bounding box( 1 ). , , .





An example of an undirected graph for 28 categories with "delimiters"

A special thank-you to my wife, Alena, for her particular support during this "dive into the darkness".





Thank you all for your attention. Constructive criticism and suggestions, both corrections and new projects, are welcome.







