An end-to-end deep learning geospatial segmentation project using PyTorch and TorchGeo packages
Introduction
In the previous story (Part 1, here), we saw how to create a RasterDataset using TorchGeo and a RandomSampler to draw patches from it. In this story, we move further and combine the images and the masks into an object called IntersectionDataset, so that we can draw training patches (images) together with their corresponding labels (water masks).
So let’s get started.
Datasets
Before continuing, we will replicate the code from the previous story just to have the environment ready in Colab. After that, we will follow the same procedure we used to create the RasterDataset for the images to create the dataset for the masks. One point that needs attention is to tell the RasterDataset class that our masks are not “images”. This way, a sample drawn from the masks dataset returns its data under the “mask” key instead of the “image” key normally used for images. Note that the code output from our gist shows the keys for each sample.
OBS: Another difference from the previous story is that here we pass a transform function to the images dataset to convert the raw values to reflectance values (division by 10,000).
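A transform of this kind can be as small as the function below. It is a sketch, not the notebook's code: it assumes the sample is a dictionary with the pixel values under the "image" key (which is how TorchGeo datasets return samples) and that the raw values are reflectances multiplied by 10,000. The name scale_to_reflectance is hypothetical.

```python
import torch


def scale_to_reflectance(sample: dict) -> dict:
    """Convert the raw values in sample["image"] to reflectance."""
    out = dict(sample)  # shallow copy, so the caller's dictionary is left untouched
    out["image"] = out["image"].float() / 10_000
    return out
```

It would be passed at construction time through the transforms argument of the images dataset (transforms=scale_to_reflectance). The masks dataset gets no such transform, since its values are class labels.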
Additionally, we can check that the bounding boxes are the same, so both samples come from the same geographic region. Here is the notebook:
Once we have both datasets (images and masks) set up, we can combine them in a very convenient way, like so:
train_dset = train_imgs & train_msks
Now, when we draw samples from this new dataset, the resulting dictionary has an entry with the image data (already divided by 10,000) and an entry with the mask.
bbox = next(iter(sampler))  # a bounding box drawn by the sampler defined in Part 1
sample = train_dset[bbox]
sample.keys()
output:
dict_keys(['image', 'crs', 'bbox', 'mask'])
DataLoaders
Creating a DataLoader in TorchGeo is very straightforward, just like in PyTorch (we are actually using the same class). Note below that we are also reusing the sampler defined previously. Additionally, we pass the dataset to pull data from, the batch_size (number of samples in each batch) and a collate function that specifies how to “concatenate” multiple samples into a single batch.
Finally, we can iterate through the dataloader to grab batches from it. To test it, we will get the first batch.
from torch.utils.data import DataLoader
from torchgeo.datasets import stack_samples
dataloader = DataLoader(train_dset, sampler=sampler, batch_size=8, collate_fn=stack_samples)
batch = next(iter(dataloader))
batch.keys()
output:
dict_keys(['image', 'crs', 'bbox', 'mask'])
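To make the role of the collate function more concrete, here is a simplified, hypothetical re-implementation of what stack_samples does. It turns a list of sample dictionaries into one dictionary of lists and stacks every tensor entry along a new batch dimension, while non-tensor entries such as crs and bbox stay as plain lists. It is only for illustration; in the pipeline, TorchGeo's own stack_samples should be used.

```python
import torch


def simple_stack_samples(samples: list) -> dict:
    """Collate a list of sample dicts into a single batch dict (sketch)."""
    # List of dicts -> dict of lists, keyed like the first sample.
    batch = {key: [sample[key] for sample in samples] for key in samples[0]}
    for key, values in batch.items():
        if isinstance(values[0], torch.Tensor):
            # Tensors are stacked along a new first axis: (C, H, W) -> (N, C, H, W)
            batch[key] = torch.stack(values)
    return batch
```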
Batch Visualization
Now that we can draw batches from our datasets, let’s create a function to display the batches.
The plot_batch function automatically checks the number of items in the batch and whether masks are present, and arranges the output grid accordingly.
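The notebook's plot_batch is not reproduced here, but a minimal sketch of the same idea could look like the function below. It assumes the batch dictionary layout shown above and that bands 2, 1 and 0 are red, green and blue (consistent with the band indices used later for NDVI and NDWI, but still an assumption). The brightness factor is a hypothetical parameter that makes dark reflectance values visible. One row is drawn per sample, with a second column only when masks are present.

```python
import matplotlib.pyplot as plt
import torch


def plot_batch(batch: dict, rgb_bands=(2, 1, 0), brightness: float = 3.0):
    """Plot each image of a batch in true color, next to its mask if present."""
    images: torch.Tensor = batch["image"].detach().cpu()
    masks = batch.get("mask")
    if masks is not None:
        masks = masks.detach().cpu()
    n_samples = images.shape[0]
    n_cols = 2 if masks is not None else 1

    fig, axs = plt.subplots(n_samples, n_cols, figsize=(4 * n_cols, 4 * n_samples), squeeze=False)
    for i in range(n_samples):
        # Pick the RGB bands and move channels last: (3, H, W) -> (H, W, 3)
        rgb = images[i, list(rgb_bands)].permute(1, 2, 0).float()
        axs[i, 0].imshow((rgb * brightness).clamp(0, 1).numpy())
        axs[i, 0].set_title("image")
        axs[i, 0].axis("off")
        if masks is not None:
            axs[i, 1].imshow(masks[i, 0].numpy(), cmap="Blues")
            axs[i, 1].set_title("mask")
            axs[i, 1].axis("off")
    fig.tight_layout()
    return fig
```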
Now, plotting our batch:
plot_batch(batch)
Data Normalization (Standardization)
Machine learning methods (deep learning included) usually benefit from feature scaling. Here we will standardize each channel to zero mean and unit standard deviation by subtracting the channel mean μ and dividing by the channel standard deviation σ, that is, x' = (x - μ) / σ (strictly speaking, normalization is different from standardization, but I will leave the explanation to the reader: https://www.naukri.com/learning/articles/normalization-and-standardization/).
To do that, we first need to find the mean and standard deviation of each of the 6 channels in the dataset.
Let’s define a function to calculate these statistics and store the results in the variables mean and std. We will use the previously installed rasterio package to open the images and simply average the statistics computed for each batch, channel by channel. For the standard deviation, this method is only an approximation. For a more precise calculation, please refer to: http://notmatthancock.github.io/2017/03/23/simple-batch-stat-updates.htm.
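For comparison, the exact per-channel statistics can be obtained in a single pass by accumulating, for every channel, the pixel count, the sum and the sum of squares, and only dividing at the end. The sketch below is not the notebook's function. It assumes each input is an array shaped (bands, rows, cols), which is the layout rasterio's read() returns, and that the arrays are already on the scale the model will see (reflectance, after the division by 10,000). It computes the population standard deviation, and accumulating the sums in float64 is usually adequate for values in the reflectance range.

```python
import numpy as np


def channel_stats(arrays):
    """Exact per-channel mean and (population) std over many (C, H, W) arrays."""
    total = total_sq = None
    count = 0
    for arr in arrays:
        flat = np.asarray(arr, dtype=np.float64)
        flat = flat.reshape(flat.shape[0], -1)  # (C, H*W)
        if total is None:
            total = np.zeros(flat.shape[0])
            total_sq = np.zeros(flat.shape[0])
        total += flat.sum(axis=1)
        total_sq += (flat ** 2).sum(axis=1)
        count += flat.shape[1]  # pixels per channel seen so far
    mean = total / count
    # Var[x] = E[x^2] - E[x]^2; clip tiny negative values caused by rounding
    var = np.maximum(total_sq / count - mean ** 2, 0.0)
    return mean, np.sqrt(var)
```

In practice, the arrays could come from opening each image file with rasterio and dividing the result of read() by 10,000. Calling .tolist() on the returned arrays gives plain lists like the mean and std used in this story.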
Here we have 6 values in each list. Now we have to use these values to normalize the data every time the dataloader creates a batch and passes it to the trainer. Additionally, if we want to visualize a batch, we need to “revert” the standardization, otherwise the true colors will not be correct. We will therefore create a class to do the trick. It inherits from torch.nn.Module and defines the forward method, plus a revert method to “undo” the normalization.
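A minimal sketch of such a class is shown below. It is not the notebook's code, only one way to implement the behavior described here. It assumes the batch is a dictionary with an "image" tensor shaped (N, C, H, W) and takes the per-channel statistics as lists. As explained at the end of this story, it standardizes only the first len(mean) channels and passes any appended index channels through unchanged. It also returns a new dictionary instead of modifying the input batch.

```python
import torch
from torch import nn


class MyNormalize(nn.Module):
    """Per-channel standardization with an inverse ("revert") step."""

    def __init__(self, mean, stdev):
        super().__init__()
        # Shape (1, C, 1, 1) broadcasts over batch, height and width.
        self.register_buffer("mean", torch.as_tensor(mean, dtype=torch.float32).view(1, -1, 1, 1))
        self.register_buffer("stdev", torch.as_tensor(stdev, dtype=torch.float32).view(1, -1, 1, 1))

    def _apply_to_bands(self, batch: dict, fn) -> dict:
        image = batch["image"]
        n_bands = self.mean.shape[1]
        # Only the original bands are transformed; extra channels pass through.
        bands = fn(image[:, :n_bands])
        out = dict(batch)  # new dict; the input batch is not modified
        out["image"] = torch.cat([bands, image[:, n_bands:]], dim=1)
        return out

    def forward(self, batch: dict) -> dict:
        return self._apply_to_bands(batch, lambda x: (x - self.mean) / self.stdev)

    def revert(self, batch: dict) -> dict:
        return self._apply_to_bands(batch, lambda x: x * self.stdev + self.mean)
```

Registering the statistics as buffers means they follow the module when it is moved with .to(device), so the normalization also works on GPU batches.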
Once the class is defined, we can instantiate it with the mean and std values obtained from our dataset and test the forward pass and the revert pass (code output has been suppressed).
normalize = MyNormalize(mean=mean, stdev=std)
norm_batch = normalize(batch)
plot_batch(norm_batch)
batch = normalize.revert(norm_batch)
plot_batch(batch)
Spectral Indices
To improve the performance of our neural net, we will perform some feature engineering and provide other spectral indices as additional inputs, such as NDWI (Normalized Difference Water Index), MNDWI (Modified NDWI) and NDVI (Normalized Difference Vegetation Index).
TorchGeo makes it easy to append indices to our original data. For that, we will use the transforms module, like so:
from torchgeo.transforms import indices
ndwi_transform = indices.AppendNDWI(index_green=1, index_nir=3)
transformed_batch = ndwi_transform(batch)
print(transformed_batch['image'].shape, transformed_batch['mask'].shape)
code output:
torch.Size([8, 7, 512, 512]) torch.Size([8, 1, 512, 512])
Note that instead of 6 channels, the image now has 7 channels because we appended the NDWI index.
We can now combine all the transformations we want into a Sequential object. Note that we put the normalization as the last transformation, because the spectral indices are supposed to be computed directly from raw reflectances.
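Under the hood, NDWI (in the McFeeters formulation) is a normalized difference between the green and NIR bands, NDWI = (Green - NIR) / (Green + NIR), and NDVI and MNDWI follow the same pattern with other band pairs. The hypothetical helper below shows that computation on an (N, C, H, W) tensor. It is a sketch, not TorchGeo's source, and the small eps term that guards against division by zero is our own addition.

```python
import torch


def append_normalized_difference(image: torch.Tensor, index_a: int, index_b: int,
                                 eps: float = 1e-10) -> torch.Tensor:
    """Append (band_a - band_b) / (band_a + band_b) as a new channel."""
    band_a = image[:, index_a:index_a + 1]  # keep the channel dim: (N, 1, H, W)
    band_b = image[:, index_b:index_b + 1]
    index = (band_a - band_b) / (band_a + band_b + eps)
    return torch.cat([image, index], dim=1)
```

With the band indices used in this story, append_normalized_difference(batch["image"], 1, 3) would add an NDWI channel, turning 6 channels into 7.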
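tfms = torch.nn.Sequential(
    indices.AppendNDWI(index_green=1, index_nir=3),  # NDWI: green vs. NIR
    indices.AppendNDWI(index_green=1, index_nir=5),  # MNDWI: same formula, band 5 assumed to be SWIR
    indices.AppendNDVI(index_nir=3, index_red=2),  # NDVI: NIR vs. red
    normalize  # standardization last, so the indices see raw reflectances
)
new_transformed_batch = tfms(batch)
print(new_transformed_batch['image'].shape, new_transformed_batch['mask'].shape)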
code output:
torch.Size([8, 10, 512, 512]) torch.Size([8, 1, 512, 512])
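We now have an easy way to apply the transformations we want to our original data. Note that the result has 10 channels rather than 6 + 3 = 9: the TorchGeo index transforms used here append their channel to the sample dictionary in place, so the batch had already received an NDWI channel in the previous step.
Important: the normalize module we created applies the normalization only to the original bands and ignores the appended indices. This avoids errors caused by a shape mismatch between the batch and the mean and std vectors.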
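To see why the indices must be computed before the standardization, consider a single pixel with hypothetical values. For non-negative reflectances a normalized difference always stays within [-1, 1], but after standardization the bands can turn negative, their sum can get close to zero, and the "index" can explode. The numbers below (reflectances 0.25 and 0.75, with mean 0.5 and standard deviation 0.25 for both bands) were picked only to make the effect obvious.

```python
import torch


def normalized_difference(a: torch.Tensor, b: torch.Tensor, eps: float = 1e-10) -> torch.Tensor:
    return (a - b) / (a + b + eps)


# Hypothetical raw reflectances for one pixel
green = torch.tensor([0.25], dtype=torch.float64)
nir = torch.tensor([0.75], dtype=torch.float64)

# Index on raw reflectances: bounded in [-1, 1]
raw_ndwi = normalized_difference(green, nir)

# Hypothetical statistics: mean 0.5 and std 0.25 for both bands
green_std = (green - 0.5) / 0.25  # -1.0
nir_std = (nir - 0.5) / 0.25  # +1.0

# Same formula on standardized values: the denominator is (almost) zero
std_ndwi = normalized_difference(green_std, nir_std)

print(raw_ndwi.item(), std_ndwi.item())  # roughly -0.5 and -2e10
```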
Notebook
Here is the full Colab notebook for this story:
Conclusion
In today’s story we’ve seen how to create an IntersectionDataset by combining the images and masks datasets. Additionally, we saw how to use nn.Sequential to chain transformations, such as spectral indices and normalization, on the original data. These could also be used to add augmentations, but that is a more advanced topic that will not be covered in this series.
In the next story we are going to create the training loop, loss function and check the results of our newly created deep neural network. So, if you are curious, stay tuned and don’t forget to follow us.
See you in the next story.