Disentangled representation learning in cardiac image analysis
Introduction
Learning good data representations is a long-running goal of machine learning (Bengio et al., 2013a). In general, representations are considered “good” if they capture explanatory (discriminative) factors of the data and are useful for the task(s) being considered. Learning good data representations for medical imaging tasks poses additional challenges, since the representation must lend itself to a range of medically useful tasks and work across data from various image modalities.
Within deep learning research there has recently been renewed focus on methods for learning so-called “disentangled” representations, for example in Higgins et al. (2017) and Chen et al. (2016). A disentangled representation is one in which information is represented as a number of (independent) factors, with each factor corresponding to some meaningful aspect of the data (Bengio et al., 2013a); hence such representations are also sometimes called factorised representations.
Disentangled representations offer many benefits. For example, they preserve information not directly related to the primary task, which would otherwise be discarded, while also allowing later tasks to use only the relevant aspects of the data as input. Furthermore, and importantly, they improve the interpretability of the learned features, since each factor captures a distinct attribute of the data and varies independently of the other factors.
Disentangled representations have considerable potential in the analysis of medical data. In this paper we combine recent developments in disentangled representation learning with strong prior knowledge about medical image data: that it necessarily decomposes into an “anatomy factor” and a “modality factor”.
An anatomy factor that is explicitly spatial (represented as a multi-class semantic map) can maintain pixel-level correspondences with the input, and directly supports spatially equivariant tasks such as segmentation and registration. Most importantly, it also allows a meaningful representation of the anatomy that can be generalised to any modality. As we demonstrate below, a spatial anatomical representation is useful for various modality independent tasks, for example in extracting segmentations as well as in calculating cardiac functional indices. It also provides a suitable format for pooling information from various imaging modalities.
The non-spatial modality factor captures global image modality information, specifying how the anatomy is rendered in the final image. Maintaining a representation of the modality characteristics allows, among other things, the ability to use data from different modalities.
Finally, the ability to learn this factorisation using a very limited number of labels is of considerable significance in medical image analysis, as labelling data is tedious and costly. Thus, it will be demonstrated that the proposed factorisation, in addition to being intuitive and interpretable, also leads to considerable performance improvements in segmentation tasks when using a very limited number of labelled images.
Learning a decomposition of data into a spatial content factor and a non-spatial style factor has been a focus of recent research in computer vision (Huang et al., 2018; Lee et al., 2018), with the aim of achieving diversity in style transfer between domains. However, these methods give no consideration to the semantics and the precision of the spatial factor, which is crucial in medical analysis tasks in order to extract quantifiable information directly from it. Concurrently with these approaches, Chartsias et al. (2018) addressed precisely this need for interpretable semantics by explicitly enforcing the spatial factor to be a binary myocardial segmentation. However, since the spatial factor is a segmentation mask of only the myocardium, the remaining anatomy must be encoded in the non-spatial factor, which violates the concept of an explicit factorisation into anatomical and modality factors.
In this paper, instead, we propose the Spatial Decomposition Network (SDNet), a schematic of which is shown in Fig. 1, which learns a disentangled representation of medical images consisting of a spatial map that semantically represents the anatomy, and a non-spatial latent vector containing image modality information.
The anatomy is modelled as a multi-channel feature map, where each channel represents different anatomical substructures (e.g. myocardium, left and right ventricles). This spatial representation is categorical with each pixel necessarily belonging to exactly one channel. This strong restriction prevents the binary maps from encoding modality information, encouraging the anatomy factors to be modality-agnostic (invariant), and further promotes factorisation of the subject’s anatomy into meaningful topological regions.
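The pixel-wise categorical constraint on the anatomy factor can be illustrated with a minimal numpy sketch. This is an illustration, not the authors' implementation; in training, a straight-through estimator would typically let gradients flow through the rounding step:

```python
import numpy as np

def to_categorical_map(logits):
    """Convert decoder logits of shape (H, W, C) into a categorical
    spatial map: softmax over channels, then a one-hot of the argmax,
    so each pixel belongs to exactly one channel."""
    e = np.exp(logits - logits.max(axis=-1, keepdims=True))
    probs = e / e.sum(axis=-1, keepdims=True)          # channel-wise softmax
    return np.eye(logits.shape[-1])[probs.argmax(-1)]  # one-hot per pixel

logits = np.random.randn(4, 4, 3)   # toy 4x4 map with C=3 channels
s = to_categorical_map(logits)
assert s.shape == (4, 4, 3)
assert np.all(s.sum(axis=-1) == 1)  # exactly one active channel per pixel
```

Because each pixel is forced onto a binary simplex vertex, the map cannot carry continuous intensity information, which is what pushes modality information into the other factor.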
On the other hand, the non-spatial factor contains modality-specific information, in particular the distribution of intensities of the spatial regions. We encode the image intensities into a smooth latent space, using a Variational Autoencoder (VAE) loss, such that nearby values in this space correspond to neighbouring values in the intensity space.
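The VAE encoding of the modality factor rests on the standard reparameterisation trick and a KL penalty towards a unit Gaussian, which keeps the latent space smooth. A minimal numpy sketch (the 8-dimensional z is an illustrative size, not necessarily the paper's):

```python
import numpy as np

rng = np.random.default_rng(0)

def sample_modality(mu, log_var):
    """Reparameterisation trick: z = mu + sigma * eps with eps ~ N(0, I),
    so z is differentiable w.r.t. the encoder outputs (mu, log_var)."""
    eps = rng.standard_normal(mu.shape)
    return mu + np.exp(0.5 * log_var) * eps

def kl_to_unit_gaussian(mu, log_var):
    """KL(q(z|x) || N(0, I)) per sample; penalising this keeps the
    modality space smooth so nearby z yield similar intensities."""
    return 0.5 * np.sum(np.exp(log_var) + mu**2 - 1.0 - log_var)

mu, log_var = np.zeros(8), np.zeros(8)  # toy 8-dimensional modality factor
z = sample_modality(mu, log_var)
assert z.shape == (8,)
assert kl_to_unit_gaussian(mu, log_var) == 0.0  # KL vanishes at the prior
```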
Finally, since the representation should retain most of the required information about the input (albeit in two factors), image reconstructions are possible by combining both factors.
In the literature the term “factor” usually refers to either a single dimension of a latent representation, or a meaningful aspect of the data (i.e. a group of dimensions) that can vary independently from other aspects. Here we use factor in the second sense, and we thus learn a representation that consists of a (multi-dimensional) anatomy factor, and a (multi-dimensional) modality factor. Although the individual dimensions of the factors could be seen as (sub-)factors themselves, for clarity we will refer to them as dimensions throughout the paper.
Our main contributions are as follows:
- With the use of a few segmentation labels and a reconstruction cost, we learn a multi-channel spatial representation of the anatomy. We specifically restrict this representation to be semantically meaningful by imposing that it is a discrete categorical variable, such that different channels represent different anatomical regions.
- We learn a modality representation using a VAE, which allows sampling in the modality space. This facilitates the decomposition, permits latent space arithmetic, and also allows us to use part of our network as a generative model to synthesise new images.
- We detail design choices, such as using Feature-wise Linear Modulation (FiLM) (Perez et al., 2018) in the decoder, which ensure that the modality factor does not contain anatomical information and prevent posterior collapse of the VAE.
- We demonstrate our method on a multi-class segmentation task across different datasets, and show that it maintains good performance even when trained with labelled images from only a single subject.
- We show that our semantic anatomical representation is useful for other anatomical tasks, such as inferring the Left Ventricular Volume (LVV). More critically, we show that we can also learn from such auxiliary tasks, demonstrating the benefits of multi-task learning whilst also improving the learned representation.
- Finally, we demonstrate that our method is suitable for multimodal learning (here, multimodal refers to multiple image modalities, not to multiple modes in a statistical sense): a single encoder is used for both MR and CT data, and we show that information from additional modalities improves segmentation accuracy.
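The FiLM conditioning mentioned among the contributions can be sketched as follows. In the actual model a learned network maps z to the per-channel parameters (γ, β); here they are passed in directly for illustration:

```python
import numpy as np

def film(features, gamma, beta):
    """Feature-wise Linear Modulation (Perez et al., 2018): each decoder
    feature channel is scaled by gamma and shifted by beta, both derived
    from the modality factor z. Because each channel receives only two
    scalars, z has too little bandwidth to inject spatial (anatomical)
    content into the output."""
    return gamma[None, None, :] * features + beta[None, None, :]

feats = np.ones((4, 4, 2))  # toy (H, W, C) decoder features
out = film(feats, gamma=np.array([2.0, 0.5]), beta=np.array([1.0, -1.0]))
assert out.shape == (4, 4, 2)
assert np.all(out[..., 0] == 3.0)   # 2.0 * 1 + 1.0
assert np.all(out[..., 1] == -0.5)  # 0.5 * 1 - 1.0
```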
In this paper we advance our preliminary work (Chartsias et al., 2018) in the following aspects: 1) we learn a general anatomical representation useful for multi-task learning; 2) we perform multi-class segmentation (of multiple cardiac substructures); 3) we impose a structure on the imaging factor, which follows a multi-dimensional Gaussian distribution; this allows sampling and improves generalisation; 4) we formulate the reconstruction process to use FiLM normalisation (Perez et al., 2018), instead of concatenating the two factors; and 5) we offer a series of experiments using four different datasets to show the capabilities and expressiveness of our representation.
The rest of the paper is organised as follows: Section 2 reviews related literature in representation learning and segmentation. Then, Section 3 describes our proposed approach. Sections 4 and 5 describe the setup and results of the experiments performed. Finally, Section 6 concludes the manuscript.
Related work
Here we review previous work on disentangled representation learning, which is typically a focus of research on generative models (Section 2.1). We then review its application in domain adaptation, which is achieved by a factorisation of style and content (Section 2.2). Finally, we review semi-supervised methods in medical imaging, as well as recent literature in cardiac segmentation, since they are related to the application domain of our method (Sections 2.3 and 2.4).
Materials and methods
Overall, our proposed model can be considered as an autoencoder, which takes as input a 2D volume slice x ∈ X, where X ⊂ ℝ^(H×W×1) is the set of all images in the data, with H and W being the image’s height and width respectively. The model generates a reconstruction through an intermediate disentangled representation, comprised of a multi-channel spatial map (a tensor) s ∈ S := {0, 1}^(H×W×C), where C is the number of channels, and a multi-dimensional non-spatial vector z that encodes the image modality.
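Because the anatomy factor is a binary spatial map in pixel correspondence with the input, modality-independent indices such as the left ventricular volume can be read directly off it. A toy numpy sketch (the channel index and voxel volume are illustrative values, not from the paper):

```python
import numpy as np

def lv_volume(anatomy_maps, lv_channel, voxel_volume_ml):
    """Estimate left ventricular volume from the binary anatomy factor:
    count the pixels of the LV-cavity channel across all slices and
    multiply by the physical volume of one voxel."""
    pixels = sum(s[..., lv_channel].sum() for s in anatomy_maps)
    return pixels * voxel_volume_ml

# toy stack of two 4x4 slices with C=3 channels, LV cavity in channel 1
slices = [np.zeros((4, 4, 3)), np.zeros((4, 4, 3))]
slices[0][1:3, 1:3, 1] = 1  # 4 LV pixels in slice 0
slices[1][0:2, 0:2, 1] = 1  # 4 LV pixels in slice 1
volume = lv_volume(slices, lv_channel=1, voxel_volume_ml=0.01)  # 8 voxels
```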
Data
In our experiments we use 2D images from four datasets, which have been normalised to the range [-1, 1].
- (a) For the semi-supervised segmentation experiment (Section 5.1) and the latent space arithmetic (Section 5.5), we use data from the 2017 Automatic Cardiac Diagnosis Challenge (ACDC) (Bernard et al., 2018). This dataset contains cine-MR images acquired on 1.5T and 3T MR scanners, with a resolution between 1.22 and 1.68 mm²/pixel and between 28 and 40 phases per patient.
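The [-1, 1] intensity normalisation applied to all four datasets can be sketched as follows (per-image min/max scaling is an assumption; the text does not specify the exact scheme):

```python
import numpy as np

def normalise(image):
    """Linearly rescale image intensities to the range [-1, 1]
    using per-image minimum/maximum scaling."""
    lo, hi = image.min(), image.max()
    return 2.0 * (image - lo) / (hi - lo) - 1.0

img = np.array([[0.0, 50.0], [100.0, 200.0]])
out = normalise(img)
assert out.min() == -1.0 and out.max() == 1.0
```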
Results and discussion
Here we present and discuss quantitative and qualitative results of our method in various experimental scenarios. Initially, multi-class semi-supervised segmentation is evaluated in Section 5.1. Subsequently, Section 5.2 demonstrates multi-task learning with the addition of a regression task in the training objectives. In Section 5.3.1, SDNet is evaluated in a multimodal scenario by concurrently segmenting MR and CT data. In Section 5.4 we investigate whether the modality factor z captures modality-specific information.
Conclusion
We have presented a method for disentangling medical images into a spatial and a non-spatial latent factor, where we enforced a semantically meaningful spatial factor representing the anatomy and a non-spatial factor encoding the modality information. To the best of our knowledge, maintaining semantics in the spatial factor has not been previously investigated. Moreover, through the incorporation of a variational autoencoder, we can treat our method as a generative model, which also allows us to synthesise new images by sampling the modality space.
Conflict of interest
We wish to confirm that there are no known conflicts of interest associated with this publication entitled “Disentangled Representation Learning in Cardiac Image Analysis” and there has been no significant financial support for this work that could have influenced its outcome.
Acknowledgements
This work was supported in part by the US National Institutes of Health (1R01HL136578-01) and UK EPSRC (EP/P022928/1), and used resources provided by the Edinburgh Compute and Data Facility (http://www.ecdf.ed.ac.uk/). S.A. Tsaftaris acknowledges the support of the Royal Academy of Engineering and the Research Chairs and Senior Research Fellowships scheme.
References (50)
- Automated cardiovascular magnetic resonance image analysis with fully convolutional networks. J. Cardiovasc. Magn. Reson. (2018)
- Chollet, F., et al. Keras (2015)
- Ω-Net (omega-net): fully automatic, multi-view cardiac MR detection, orientation, and segmentation with deep neural networks. Med. Image Anal. (2018)
- Deep learning based instance segmentation in 3D biomedical images using weak annotation. Medical Image Computing and Computer Assisted Intervention (2018)
- A registration-based propagation framework for automatic whole heart segmentation of cardiac MRI. IEEE Trans. Med. Imag. (2010)
- Augmented CycleGAN: learning many-to-many mappings from unpaired data. International Conference on Machine Learning (2018)
- Multi-content GAN for few-shot font style transfer. Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (2018)
- Semi-supervised learning for network-based cardiac MR image segmentation. Medical Image Computing and Computer-Assisted Intervention (2017)
- Recurrent neural networks for aortic image sequence segmentation with sparse annotations
- Representation learning: a review and new perspectives. IEEE Trans. Pattern Anal. Mach. Intell. (2013)
- Estimating or propagating gradients through stochastic neurons for conditional computation
- Deep learning techniques for automatic MRI cardiac multi-structures segmentation and diagnosis: is the problem solved? IEEE Trans. Med. Imag.
- Dictionary-driven ischemia detection from cardiac phase-resolved myocardial BOLD MRI at rest. IEEE Trans. Med. Imag.
- Learning interpretable anatomical features through deep generative models: application to cardiac remodeling
- Understanding disentangling in β-VAE. NIPS Workshop on Learning Disentangled Representations
- Adversarial image synthesis for unpaired multi-modal cardiac data. Simulation and Synthesis in Medical Imaging
- Factorised spatial representation learning: application in semi-supervised myocardial segmentation. Medical Image Computing and Computer Assisted Intervention
- InfoGAN: interpretable representation learning by information maximizing generative adversarial nets. Advances in Neural Information Processing Systems
- Not-so-supervised: a survey of semi-supervised, multi-instance, and transfer learning in medical image analysis. Med. Image Anal.
- Semantically decomposing the latent spaces of generative adversarial networks. International Conference on Learning Representations
- A variational U-net for conditional appearance and shape generation. Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition
- Scalable multimodal convolutional networks for brain tumour segmentation. Medical Image Computing and Computer-Assisted Intervention
- Image style transfer using convolutional neural networks. Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition
- Generative adversarial nets. Advances in Neural Information Processing Systems
- β-VAE: learning basic visual concepts with a constrained variational framework. International Conference on Learning Representations