Layer Ensembles: A Single-Pass Uncertainty Estimation in Deep Learning for Segmentation

Layer Ensembles is a method for uncertainty estimation in deep learning using a single network and a single pass (see this paper).

Description

Method

The main idea of Layer Ensembles is to attach an output head to intermediate layers of a network at different depths like in the following example:

Layer Ensembles Architecture

Then, the outputs can be combined as an ensemble of networks of different depths. We tested this idea for two medical image segmentation tasks: 1) binary – breast mass segmentation; and 2) multi-class – cardiac structure segmentation.

AULA (Area Under Layer Agreement curve) is a new image-level uncertainty metric derived from the sequential nature of the Layer Ensembles method.

It is a known phenomenon in Deep Learning that early layers generalise and deep layers memorise. Easy samples (i.e., common and in-distribution) are therefore already adequately segmented by the early layers and by all subsequent layers, yielding high agreement between adjacent layers. For difficult samples (i.e., less common and out-of-distribution), however, the early layers agree much less.

Look at this example for mass segmentation (blue → prediction; green → ground truth):

AULA illustration

In (a), we have a high contrast lesion with clear boundaries to the surrounding tissue. Hence, we observe a high agreement between the segmentation predictions for all the layers. We calculate the agreement between two layers using the Dice Coefficient. Then, we can calculate the Area Under Layer Agreement curve as shown in the shaded blue area.

In (b), we have a mass that has a low contrast to the surrounding tissue and the calcification pathology is present. This example is one of the difficult samples. We can see the agreement in the early layers is low. Hence, low AULA.

In the results below, you can see that AULA correlates well with the segmentation performance metrics: AULA is high for good segmentations and low for poor ones. This allows failed segmentations to be detected automatically.
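The layer-agreement computation described above can be sketched as follows. The helper names `dice` and `aula` are hypothetical, and the exact normalisation used in the paper may differ; this is a minimal illustration, not the reference implementation:

```python
import torch

def dice(a, b, eps=1e-7):
    # Dice coefficient between two binary masks
    inter = (a * b).sum()
    return float((2 * inter + eps) / (a.sum() + b.sum() + eps))

def aula(preds):
    # preds: per-head binary masks, ordered from shallowest to deepest head
    agree = [dice(preds[i], preds[i + 1]) for i in range(len(preds) - 1)]
    # Area under the layer-agreement curve (trapezoidal rule),
    # normalised by the number of intervals so a perfect score is 1.0
    area = sum((agree[i] + agree[i + 1]) / 2 for i in range(len(agree) - 1))
    return area / max(len(agree) - 1, 1)
```

If every head predicts the same mask, agreement is 1 everywhere and AULA is maximal; disagreement in the early heads lowers the area, flagging a difficult sample.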

How to apply Layer Ensembles to your model

  1. Load any model:

```python
from torchvision.models import resnet18

architecture = resnet18(weights=None, num_classes=2)
```
  2. Import the LayerEnsembles wrapper and the task Enum (e.g., segmentation, classification, regression):

```python
from methods.layer_ensembles import LayerEnsembles
from utils import Task
```

This is an example of importing them in main.py. At the moment, only Task.SEGMENTATION is supported.

  3. Get the names of all the layers in your model:

```python
all_layers = dict([*architecture.named_modules()])
intermediate_layers = []
for name, layer in all_layers.items():
    if '.relu' in name:
        intermediate_layers.append(name)
```

You can change .relu to any other component, e.g., .bn or .conv. The leading . restricts the match to sub-modules, excluding the stem.

  4. Initialise LayerEnsembles with the names of the intermediate layers to use as outputs:

```python
import torch

model = LayerEnsembles(architecture, intermediate_layers)
# Dummy input to get the output shapes of the layers
x = torch.randn(1, 1, 128, 128)
output = model(x)
out_channels = []
for key, val in output.items():
    out_channels.append(val.shape[1])
# Set the output heads with the number of channels of the output layers
model.set_output_heads(in_channels=out_channels, task=Task.SEGMENTATION, classes=2)
```
  5. Check the output shapes:

```python
outputs = model(x)
print(len(outputs))
for layer, out in outputs.items():
    print(layer, out.shape)
```
  6. Training goes as usual, but note that outputs is a dictionary of tensors keyed by output head name. We therefore calculate total_loss as the sum of the losses over all output heads and backpropagate once:

```python
# Assumes criterion, target, and optimizer are defined as usual
model.train()
optimizer.zero_grad()
outputs = model(x)
losses = [criterion(output, target) for _, output in outputs.items()]
total_loss = sum(losses)
total_loss.backward()
optimizer.step()
```

Feel free to modify the loss functions and how the total loss is calculated.

  7. In testing, the outputs dictionary contains a prediction from each head. You can combine them in any way you like (e.g., averaging, STAPLE).
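One possible way to combine the heads is sketched below, assuming a dictionary of per-head logits like the one returned by the model (here filled with random tensors for illustration): average the softmax probabilities for the prediction and use the across-head variance as an uncertainty heatmap.

```python
import torch

# Stand-in for the per-head logits returned by a LayerEnsembles model
# (in practice: outputs = model(x) with the model in eval mode)
outputs = {f"head_{i}": torch.randn(1, 2, 128, 128) for i in range(5)}

probs = torch.stack([torch.softmax(o, dim=1) for o in outputs.values()])
mean_prob = probs.mean(dim=0)              # (1, 2, 128, 128) ensemble probability
prediction = mean_prob.argmax(dim=1)       # (1, 128, 128) final segmentation
uncertainty = probs.var(dim=0).sum(dim=1)  # (1, 128, 128) variance heatmap
```

The variance heatmap is the same quantity used in the qualitative results below; averaging could equally be replaced by STAPLE or a majority vote.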

Results

Segmentation and Calibration

We compare the segmentation and uncertainty estimation performance against the state-of-the-art Deep Ensembles (DE), where \(M\) randomly initialised models are trained on the same data and their outputs are combined to get the final prediction. We also compare Layer Ensembles (LE) with Multi-Head Ensembles (MH), another single-pass, single-network approach, as well as with a baseline model without any uncertainty estimation (referred to as Plain).

Results for breast mass segmentation (left three metric columns) and cardiac structure segmentation (right three):

| Method | DSC ↑ | MHD ↓ | NLL ↓ | DSC ↑ | MHD ↓ | NLL ↓ |
|--------|-------|-------|-------|-------|-------|-------|
| Plain  | 0.865 | 1.429 | 2.312 | 0.900 | 1.061 | 0.182 |
| MH     | 0.865 | 1.457 | 2.191 | 0.892 | 1.418 | 0.228 |
| DE     | 0.870 | 1.373 | 0.615 | 0.896 | 1.465 | 0.157 |
| LE     | 0.872 | 1.317 | 0.306 | 0.903 | 1.302 | 0.173 |

Qualitative examples

Here are some examples of segmentation outputs and uncertainty heatmaps based on variance. DE's uncertainty maps are overconfident, while LE manages to highlight the difficult areas.

Qualitative results

Figure 1. Examples of visual uncertainty heatmaps based on variance for high uncertainty areas (red arrows) using LE (top) and DE (bottom) for breast mass and cardiac structure segmentation. Black and green contours correspond to ground truth.

Correlation of uncertainty metrics with segmentation performance

Correlations for breast mass segmentation (left four columns) and cardiac structure segmentation (right four):

|        | Entropy | MI     | Variance | AULA   | Entropy | MI     | Variance | AULA   |
|--------|---------|--------|----------|--------|---------|--------|----------|--------|
| DE-DSC | -0.783  | -0.785 | -0.785   | N/A    | -0.323  | -0.433 | -0.377   | N/A    |
| LE-DSC | -0.615  | -0.597 | -0.620   | 0.785  | -0.221  | -0.207 | -0.203   | 0.649  |
| DE-MHD | 0.762   | 0.764  | 0.763    | N/A    | 0.401   | 0.499  | 0.447    | N/A    |
| LE-MHD | 0.594   | 0.575  | 0.598    | -0.730 | 0.309   | 0.313  | 0.300    | -0.571 |

AULA correlates with segmentation performance at least as strongly as DE's Variance, MI, and Entropy uncertainty metrics. Poor segmentations can thus be flagged automatically and referred to clinicians for manual revision:

Quality control

Figure 2. Segmentation quality control for DE and LE. Averaged indicators are shown for random flagging (dashed black), the remaining 5% of poor segmentations (dotted grey), and the ideal line (grey shaded area).

Tuning the Negative Log Likelihood

Since we attach the output heads sequentially, we can skip the outputs of the early layers to tune the NLL to our needs:

NLL tuning

Figure 3. The effect of skipping initial segmentation head outputs on model calibration. Numbers on top of the lines represent DSC in mean(std) format. Shaded areas are standard deviations for NLL.
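Skipping the first heads amounts to dropping the shallowest entries before ensembling. A minimal sketch, where the dictionary of logits and the choice of k = 3 are illustrative only:

```python
import torch

# Stand-in for per-head logits, ordered from shallowest to deepest head
outputs = {f"head_{i}": torch.randn(1, 2, 64, 64) for i in range(10)}

k = 3  # number of early heads to skip (illustrative choice)
kept = list(outputs.values())[k:]
probs = torch.stack([torch.softmax(o, dim=1) for o in kept])
mean_prob = probs.mean(dim=0)  # ensemble over the 10 - k deepest heads only
```

As Figure 3 suggests, larger k tends to improve calibration (lower NLL) with little change in DSC, at the cost of a smaller ensemble.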

How to cite

@inproceedings{kushibar2022layer,
  title={Layer Ensembles: A Single-Pass Uncertainty Estimation in Deep Learning for Segmentation},
  author={Kushibar, Kaisar and Campello, Victor and Garrucho, Lidia and Linardos, Akis and Radeva, Petia and Lekadir, Karim},
  booktitle={International Conference on Medical Image Computing and Computer-Assisted Intervention},
  pages={514--524},
  year={2022},
  organization={Springer}
}