
Artificial Intelligence for Geospatial Analysis with Pytorch’s TorchGeo (Part 3)

Writer: Mauricio Cordeiro

An end-to-end deep learning geospatial segmentation project using Pytorch and TorchGeo packages


Photo by NASA on Unsplash

Introduction

In the previous stories (Part 1 and Part 2), we saw how to prepare a dataset of raster (multispectral) images and combine them with the corresponding labels (ground truth masks) using the IntersectionDataset provided by TorchGeo. To draw samples from it (the smaller, fixed-size patches required for training), we used the RandomGeoSampler together with the DataLoader object, which is responsible for providing the batches (groups of samples) to the training procedure.

In addition, we added spectral indices and normalization to each batch using the nn.Sequential class. Now, in this last part, we will see how to create a model that is capable of “learning” to correctly segment our images and how to put everything together into a training loop.

So let’s get started!

Baseline

First, we need to pick up where we left off and prepare the datasets and DataLoaders for the task. We will create two DataLoaders, one for training and one for validation, following the same procedures as in the previous parts.

Datasets and DataLoaders



The length specified in the sampler is the number of samples provided in one “pass” (also called an epoch). Normally, this equals the number of patches in the dataset, so that one epoch goes through all the samples. However, as we are working with a random sampler, we cannot guarantee that the whole region is covered. Therefore, I defined the length as four times the number of images in each dataset (train and validation).
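To see why a random sampler cannot guarantee full coverage, here is a toy simulation in plain Python. It is not TorchGeo's RandomGeoSampler: random_patch_coverage is a hypothetical helper that drops square patches at uniformly random positions inside a grid and reports the fraction of pixels hit at least once. The grid and patch sizes are arbitrary assumptions chosen so that 16 non-overlapping patches would tile the grid exactly.

```python
import random


def random_patch_coverage(width: int, height: int, patch: int,
                          n_patches: int, seed: int = 0) -> float:
    """Fraction of a width x height grid hit by at least one random patch.

    Toy stand-in for a random sampler (hypothetical helper): each square patch
    is placed at a uniformly random position that keeps it fully inside the grid.
    """
    rng = random.Random(seed)
    covered = [[False] * width for _ in range(height)]
    for _ in range(n_patches):
        x0 = rng.randint(0, width - patch)   # inclusive bounds keep the patch inside
        y0 = rng.randint(0, height - patch)
        for row in covered[y0:y0 + patch]:
            row[x0:x0 + patch] = [True] * patch
    return sum(map(sum, covered)) / (width * height)


# A 64x64 grid could be tiled exactly by 16 non-overlapping 16x16 patches
print(random_patch_coverage(64, 64, 16, n_patches=16))
# Drawing four times as many random patches raises coverage, without guaranteeing it
print(random_patch_coverage(64, 64, 16, n_patches=4 * 16))
```

Because a patch must lie fully inside the grid, pixels near the borders can be reached by fewer patch positions than pixels in the center, so with uniform placement the edges tend to be sampled less often. Increasing the sampler length raises the expected coverage but never guarantees it.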

Now, let’s create the DataLoaders and check if they are working as expected.
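The real DataLoaders wrap the TorchGeo datasets and samplers described above. As a self-contained illustration of the shape check, the sketch below uses a hypothetical MockPatchDataset that returns dicts with 'image' and 'mask' keys, similar to TorchGeo samples, filled with random values. The patch size (64) and batch size (8) are arbitrary assumptions. Real TorchGeo samples also carry non-tensor metadata such as the CRS and bounding box, which is why TorchGeo provides the stack_samples collate function to pass to the DataLoader; the mock omits that metadata and relies on PyTorch's default collation.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class MockPatchDataset(Dataset):
    """Hypothetical stand-in for the TorchGeo intersection dataset + sampler.

    Each sample mimics a TorchGeo-style dict: a 6-band 'image' patch and a
    single-channel 'mask' patch with class indices 0 (no water) / 1 (water).
    """

    def __init__(self, length: int, bands: int = 6, size: int = 64):
        self.length, self.bands, self.size = length, bands, size

    def __len__(self) -> int:
        return self.length

    def __getitem__(self, idx: int) -> dict:
        image = torch.rand(self.bands, self.size, self.size)
        mask = torch.randint(0, 2, (1, self.size, self.size))
        return {'image': image, 'mask': mask}


# The default collate function stacks dicts of tensors key by key
train_dataloader = DataLoader(MockPatchDataset(length=16), batch_size=8, shuffle=True)
batch = next(iter(train_dataloader))
print(batch['image'].shape, batch['mask'].shape)
# torch.Size([8, 6, 64, 64]) torch.Size([8, 1, 64, 64])
```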


Standardization and Spectral indices

For standardization and spectral indices, the procedure is the same as presented in Parts 1 and 2, and so are the visualization routines. The notebook (see the Notebook section at the end of this story) has everything updated up to the correct batch creation.

The last cell of the notebook shows a sample from a validation batch and the shape of the validation images, with 9 channels (6 spectral bands + 3 indices), as expected.



Figure 1: Code output. Image by author.
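The transformations themselves are covered in Parts 1 and 2. For readers jumping in here, the sketch below shows the idea in plain PyTorch: a hypothetical AppendNormalizedDifference module appends one normalized-difference index per call, and a hypothetical Standardize module rescales each channel with precomputed mean and standard deviation, all chained with nn.Sequential to go from 6 to 9 channels. The band order (Blue, Green, Red, NIR, SWIR1, SWIR2) and the choice of NDWI, MNDWI and NDVI as the three indices are assumptions, not necessarily what the notebook uses, and the statistics are computed from a random mock batch only to keep the example self-contained.

```python
import torch
from torch import nn


class AppendNormalizedDifference(nn.Module):
    """Append (band_a - band_b) / (band_a + band_b) as an extra channel.

    Hypothetical helper; operates on image tensors of shape (N, C, H, W).
    """

    def __init__(self, index_a: int, index_b: int, eps: float = 1e-8):
        super().__init__()
        self.index_a, self.index_b, self.eps = index_a, index_b, eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        a = x[:, self.index_a:self.index_a + 1]  # slicing keeps the channel axis
        b = x[:, self.index_b:self.index_b + 1]
        return torch.cat([x, (a - b) / (a + b + self.eps)], dim=1)


class Standardize(nn.Module):
    """Per-channel (x - mean) / std with precomputed statistics (hypothetical helper)."""

    def __init__(self, mean, std):
        super().__init__()
        self.register_buffer('mean', torch.as_tensor(mean, dtype=torch.float32).reshape(1, -1, 1, 1))
        self.register_buffer('std', torch.as_tensor(std, dtype=torch.float32).reshape(1, -1, 1, 1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return (x - self.mean) / self.std


# Assumed band order: 0 Blue, 1 Green, 2 Red, 3 NIR, 4 SWIR1, 5 SWIR2
indices = nn.Sequential(
    AppendNormalizedDifference(1, 3),  # NDWI  = (Green - NIR)   / (Green + NIR)
    AppendNormalizedDifference(1, 4),  # MNDWI = (Green - SWIR1) / (Green + SWIR1)
    AppendNormalizedDifference(3, 2),  # NDVI  = (NIR - Red)     / (NIR + Red)
)

batch = torch.rand(4, 6, 64, 64) + 0.01   # mock reflectances, strictly positive
with_indices = indices(batch)              # (4, 9, 64, 64)

# In practice the statistics come from the training set; here, from the mock batch
mean = with_indices.mean(dim=(0, 2, 3))
std = with_indices.std(dim=(0, 2, 3))
tfms = nn.Sequential(indices, Standardize(mean, std))

out = tfms(batch)
print(out.shape)  # torch.Size([4, 9, 64, 64])
```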

Segmentation Model

For the semantic segmentation model, we are going to use a predefined architecture available in torchvision, PyTorch's computer vision package. Looking at the official documentation (https://pytorch.org/vision/stable/models.html#semantic-segmentation), we can see three architectures available for semantic segmentation (FCN, DeepLabV3 and LRASPP), but one of them (LRASPP) is intended for mobile applications. In this tutorial, we will use the DeepLabV3 model.

So, let’s create a DeepLabV3 model for 2 classes. In this case, I will skip the pretrained weights, as the weights represent another domain (not water segmentation from multispectral imagery).

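from torchvision.models.segmentation import deeplabv3_resnet50

# weights=None skips the pretrained segmentation weights; weights_backbone=None also
# skips the ImageNet weights torchvision would otherwise load into the ResNet-50 backbone
model = deeplabv3_resnet50(weights=None, weights_backbone=None, num_classes=2)

model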
code output: 
DeepLabV3(
  (backbone): IntermediateLayerGetter(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
...

The first thing to pay attention to in the model architecture is the number of channels expected by the first convolution (Conv2d), which is defined as 3. That's because the model is designed to work with RGB images. This first convolution turns the 3 input channels into 64 channels at half the resolution, and so on. As we now have 9 channels, we need to adapt this first processing layer. We can do this by replacing the first convolutional layer with a new one that accepts 9 input channels, keeping the other parameters (64 output channels, 7x7 kernel, stride 2, padding 3) unchanged so that the rest of the network still fits. Finally, we check that a mock batch can pass through the model and produce an output with 2 channels (water / no_water), as desired.

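import torch
from torch import nn

# Replace the RGB stem (3 input channels) with one that accepts 9 channels
backbone = model.get_submodule('backbone')
conv = nn.Conv2d(
    in_channels=9,      # 6 spectral bands + 3 spectral indices
    out_channels=64,    # must match what the following bn1 layer expects
    kernel_size=(7, 7),
    stride=(2, 2),
    padding=(3, 3),
    bias=False
)
backbone.register_module('conv1', conv)

# Sanity check: a mock batch of three 9-channel 512x512 images
pred = model(torch.randn(3, 9, 512, 512))
pred['out'].shape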
code output: 
torch.Size([3, 2, 512, 512])

Our architecture seems to be working as expected. The next step is to train it. So let’s create a training loop for it.

Training loop

The training function should receive the number of epochs, the model, the dataloaders, the loss function (to be optimized), the accuracy functions (to assess the results), the optimizer (which adjusts the model's parameters in the right direction) and the transformations to be applied to each batch.
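The original implementation is in the notebook; the sketch below is a hypothetical reconstruction that matches the call used later in this story (train_loop(epochs, train_dl, val_dl, model, loss_fn, optimizer, acc_fns=..., batch_tfms=...)) and its printed log format. Assumptions: batches are dicts with 'image' (N, C, H, W) and 'mask' (N, 1, H, W) tensors, batch_tfms acts on the image tensor, the model returns a dict with an 'out' key (as torchvision's DeepLabV3 does), and the device argument and the returned history list are additions for convenience.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader


def train_loop(epochs, train_dl, val_dl, model, loss_fn, optimizer,
               acc_fns=None, batch_tfms=None, device='cpu'):
    """Hypothetical sketch: train for `epochs` epochs, validating after each one.

    batch_tfms (if given) must live on the same device as the batches.
    Returns a list with the mean train loss and mean accuracies per epoch.
    """
    model = model.to(device)
    acc_fns = acc_fns or []
    history = []
    for epoch in range(epochs):
        model.train()
        accum_loss = 0.0
        for batch in train_dl:
            x = batch['image'].float().to(device)
            y = batch['mask'].long().to(device)
            if batch_tfms is not None:
                x = batch_tfms(x)              # e.g. append indices + standardize
            pred = model(x)['out']
            loss = loss_fn(pred, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            accum_loss += loss.item() / len(train_dl)   # mean loss over batches

        # Validation: average each accuracy function over the validation batches
        accs = [0.0] * len(acc_fns)
        if val_dl is not None and acc_fns:
            model.eval()
            with torch.no_grad():
                for batch in val_dl:
                    x = batch['image'].float().to(device)
                    y = batch['mask'].long().to(device)
                    if batch_tfms is not None:
                        x = batch_tfms(x)
                    pred = model(x)['out']
                    for i, fn in enumerate(acc_fns):
                        accs[i] += float(fn(pred, y)) / len(val_dl)

        print(f'Epoch {epoch}: Train Loss={accum_loss:.5f} - Accs={[round(a, 3) for a in accs]}')
        history.append({'loss': accum_loss, 'accs': accs})
    return history
```

Keeping the training and validation passes in one function mirrors the single call used later; on a GPU, pass device='cuda' and move the model, transforms and optimizer state accordingly.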


Loss and Accuracy functions

Before calling the training function, let's create the loss and accuracy functions. In our case, the predictions have shape (N, C, d1, d2) and the masks have shape (N, 1, d1, d2). For the loss function, cross entropy loss would normally work, but it expects the target to contain class indices with shape (N, d1, d2). Therefore, we need to squeeze the second (channel) dimension of the masks manually.
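A minimal sketch of the loss with the manual squeeze, using torch.nn.functional.cross_entropy. The function name loss mirrors the name passed to train_loop later, but the body is an assumption rather than the notebook's exact code.

```python
import torch
import torch.nn.functional as F


def loss(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Cross entropy between logits (N, C, H, W) and masks (N, 1, H, W)."""
    # squeeze(1) removes only the channel axis; a bare squeeze() would also drop
    # the batch axis when N == 1, leaving a target that no longer matches pred
    return F.cross_entropy(pred, target.squeeze(1).long())


pred = torch.randn(4, 2, 32, 32)              # logits for 2 classes
mask = torch.randint(0, 2, (4, 1, 32, 32))    # water / no_water labels
print(loss(pred, mask))                        # scalar tensor
```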

Additionally, we will create two accuracy functions: the overall accuracy (OA), used in the original paper, and the intersection over union (IoU). When the masks have an unbalanced number of pixels in each class, as is the case for water masks (some scenes contain just land and very few water bodies), the overall accuracy tends to be unrealistically high, because correctly predicting the dominant class alone already yields a high score. In this case, OA should be avoided, but it is kept here for comparison with the original paper.
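A worked example of why OA can be misleading on unbalanced masks: take a hypothetical 100 x 100 mask with 5% water and a useless model that predicts land everywhere. The numbers are made up purely for illustration.

```python
import torch

# Hypothetical 100x100 mask: 5% water (1), 95% land (0)
mask = torch.zeros(100, 100, dtype=torch.long)
mask[:5, :] = 1                      # 500 water pixels

# A useless "model" that predicts land everywhere
pred = torch.zeros_like(mask)

oa = (pred == mask).float().mean().item()          # fraction of matching pixels
inter = ((pred == 1) & (mask == 1)).sum().item()   # water pixels found
union = ((pred == 1) | (mask == 1)).sum().item()   # water in prediction or mask
iou_water = inter / union
print(f'OA={oa:.2f}  IoU(water)={iou_water:.2f}')  # OA=0.95  IoU(water)=0.00
```

OA rewards the model for the 95% of pixels that are trivially land, while the IoU of the water class exposes that no water was found at all.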

The overall accuracy is calculated manually by counting all the matching pixels and dividing by the number of pixels in the batch. The IoU, also known as the Jaccard index, is available in the scikit-learn package. PyTorch's cross entropy is used as the loss, with a minor adjustment to the target's shape. The complete definitions are in the notebook linked at the end of this story.
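One way to write the two metrics so that they accept the same (pred, y) pair as the loss is sketched below. This is an assumption, not the notebook's code: instead of scikit-learn's jaccard_score (sklearn.metrics), which for binary labels computes the IoU of the positive class, it computes the IoU of the water class (label 1) directly in PyTorch, which avoids moving tensors to NumPy. Returning 1.0 when neither the prediction nor the mask contains water is a design choice.

```python
import torch


def oa(pred: torch.Tensor, y: torch.Tensor) -> float:
    """Overall accuracy: matching pixels / total pixels in the batch."""
    labels = pred.argmax(dim=1)   # (N, C, H, W) logits -> (N, H, W) classes
    target = y.squeeze(1)         # (N, 1, H, W) -> (N, H, W)
    return (labels == target).sum().item() / target.numel()


def iou(pred: torch.Tensor, y: torch.Tensor, positive: int = 1) -> float:
    """IoU (Jaccard index) of the positive (water) class over the batch."""
    labels = pred.argmax(dim=1) == positive
    target = y.squeeze(1) == positive
    union = (labels | target).sum().item()
    if union == 0:
        return 1.0  # no water predicted and none present: counted as perfect
    return (labels & target).sum().item() / union


pred = torch.randn(2, 2, 16, 16)          # mock logits
y = torch.randint(0, 2, (2, 1, 16, 16))   # mock masks
print(oa(pred, y), iou(pred, y))
```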


Training

The training function can now be called like so:

optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0.01)

train_loop(5, train_dataloader, valid_dataloader, model, loss, optimizer,
           acc_fns=[oa, iou], batch_tfms=tfms)
code output: 
Epoch 0: Train Loss=0.37275 - Accs=[0.921, 0.631]
Epoch 1: Train Loss=0.22578 - Accs=[0.94, 0.689]
Epoch 2: Train Loss=0.22280 - Accs=[0.906, 0.576]
Epoch 3: Train Loss=0.19370 - Accs=[0.944, 0.706]
Epoch 4: Train Loss=0.18241 - Accs=[0.92, 0.619]
Epoch 5: Train Loss=0.21393 - Accs=[0.956, 0.748]

From the results, we can see that the loss generally drops and the accuracies tend to increase, although both fluctuate from one epoch to the next, so the training is working as expected. The first accuracy is the overall accuracy and the second is the IoU. Comparing our results after 10 epochs with those obtained by the DeepLabV3 tested in the research paper (Luo et al. 2021), we have OA = 95.6% and OA = 95.7% (mean of the 3 distinct regions considered in the paper), respectively. Considering that we started from arbitrary weights and did not perform any tuning of hyperparameters such as regularization or learning rate, these results are very good. It would also be interesting to compare them with other water segmentation algorithms; single-index thresholding (MNDWI, AWEI, etc.), despite its simplicity, does not provide the best results (Cordeiro et al. 2021).
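For such a comparison, a single-index baseline is simple to build. The hypothetical mndwi_water_mask below computes MNDWI = (Green - SWIR1) / (Green + SWIR1) and labels a pixel as water when it exceeds a threshold; 0 is a common default. The band positions (Green at index 1, SWIR1 at index 4) are assumptions about the band order, and the threshold would normally need tuning per scene.

```python
import torch


def mndwi_water_mask(image: torch.Tensor, green: int = 1, swir1: int = 4,
                     threshold: float = 0.0, eps: float = 1e-8) -> torch.Tensor:
    """Single-index water mask (hypothetical baseline helper).

    image: (N, C, H, W) reflectances; band positions are assumptions.
    Returns an (N, 1, H, W) long tensor with 1 = water, 0 = no water.
    """
    g = image[:, green:green + 1]
    s = image[:, swir1:swir1 + 1]
    mndwi = (g - s) / (g + s + eps)   # eps avoids division by zero
    return (mndwi > threshold).long()


scene = torch.rand(2, 6, 32, 32)   # mock reflectances
water = mndwi_water_mask(scene)
print(water.shape)  # torch.Size([2, 1, 32, 32])
```

Its output has the same (N, 1, H, W) layout as the ground-truth masks, so it can be scored against them with the same pixel-wise metrics used for the model.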

Notebook

The complete notebook is available here and can be opened directly in Google Colab.



Conclusions

In this Part 3, we finished our project by providing a training loop to optimize a DL model (DeepLabV3) for water segmentation using satellite imagery. The results are promising, but they could be improved with some fine tuning. Besides tuning hyperparameters and training for longer, data augmentation could also be used to improve accuracy, as could a different architecture, such as U-Net. It would also be interesting to check the quality of the outputs visually, to understand where the model performs well and where it misses the target. These topics were not covered in this story, but if you would like to see more stories like this, don't hesitate to leave your requests (and thoughts) in the comments.



