How to build an AI that detects pneumonia
Artificial intelligence is used in more and more areas of medicine. One of the most current applications: an AI that helps detect pneumonia from X-ray images. Based on this use case, I will show you what is important when developing an Applied Medical AI using Convolutional Neural Networks (CNNs).
by Dzmitry Ashkinadze
Thanks to recent advances in deep neural networks (NNs), and convolutional neural networks (CNNs) in particular, it has become much easier to develop machine learning (ML) algorithms of medical relevance using open-source software. Healthcare institutions are therefore increasingly implementing AI-driven solutions: according to Grand View Research, the market for artificial intelligence in healthcare in the US was valued at USD 10.4 billion in 2021, with an expected annual growth rate of 38.4% from 2022 to 2030.
Which AI applications are paying off in healthcare?
Accenture has analysed the ten AI applications with the greatest near-term impact on healthcare. The three most important applications in the near future are robot-assisted surgery, virtual nursing assistants and applications to support administrative processes.
ML systems for pneumonia detection are highly relevant because they can assist with and expedite diagnosis, thereby facilitating earlier treatment. Put simply, they consist of convolutional neural networks that are trained on a dataset of manually labeled X-ray images (Fig. 1), with each X-ray labeled as healthy or unhealthy (pneumonia). Let’s look at how exactly such an ML system is developed and what we can learn from this use case for our own Applied Medical AI projects.
Fig. 1: X-ray images of a healthy person (above) and a person with pneumonia (below).
Data and how to prevent class bias
Like any other machine learning project, this one starts with data, namely with the publicly available dataset “OCT and Chest X-Ray images and codes”. The dataset consists of 5863 labeled X-ray images of children between 1 and 5 years old, collected at the Women and Children’s Medical Center in Guangzhou, China.
To prepare the data for the ML algorithm, the developers first converted the input images to grayscale and downsized them to 150x150 pixels for standardization and denoising. It was also important to make sure that the “healthy” and “unhealthy” classes used for the classification were balanced (meaning roughly equal in size); otherwise, the imbalance could bias the model towards predicting the dominant class. Since deep learning algorithms perform best with large datasets, the developers used augmentation (randomly rotating, zooming and shifting the images) to both balance the classes and enrich the dataset. Finally, following machine learning best practice, the data was split into a training set and a test set so that overfitting can be detected (i.e. so that one can check whether the network has picked up noise or random fluctuations in the training data and learned them as concepts).
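The balancing and splitting steps described above can be sketched in a few lines of NumPy. This is a minimal illustration, not the developers' pipeline: the random arrays stand in for real X-rays, the helper names (`random_shift`, `balance_by_augmentation`, `train_test_split`) and the 20% test fraction are my own assumptions, and a plain pixel shift stands in for the rotating, zooming and shifting mentioned in the text.

```python
import numpy as np

def random_shift(image, rng, max_shift=10):
    """Return a copy of `image` shifted by a random number of pixels (wraps around)."""
    dy, dx = rng.integers(-max_shift, max_shift + 1, size=2)
    return np.roll(image, shift=(int(dy), int(dx)), axis=(0, 1))

def balance_by_augmentation(images, labels, rng):
    """Add augmented copies of minority-class images until all classes are equal in size."""
    classes, counts = np.unique(labels, return_counts=True)
    target = int(counts.max())
    out_images, out_labels = [images], [labels]
    for cls, count in zip(classes, counts):
        missing = target - int(count)
        if missing == 0:
            continue
        source = images[labels == cls]
        picks = rng.integers(0, len(source), size=missing)
        augmented = np.stack([random_shift(source[i], rng) for i in picks])
        out_images.append(augmented)
        out_labels.append(np.full(missing, cls, dtype=labels.dtype))
    return np.concatenate(out_images), np.concatenate(out_labels)

def train_test_split(images, labels, rng, test_fraction=0.2):
    """Shuffle the data and hold out a fraction of it as a test set."""
    order = rng.permutation(len(images))
    n_test = int(len(images) * test_fraction)
    test_idx, train_idx = order[:n_test], order[n_test:]
    return images[train_idx], labels[train_idx], images[test_idx], labels[test_idx]

rng = np.random.default_rng(seed=0)
images = rng.random((100, 150, 150))        # stand-in for 150x150 grayscale X-rays
labels = np.array([0] * 30 + [1] * 70)      # 0 = healthy, 1 = pneumonia (imbalanced)

x_train, y_train, x_test, y_test = train_test_split(images, labels, rng)
x_balanced, y_balanced = balance_by_augmentation(x_train, y_train, rng)
print("training class counts after balancing:", np.bincount(y_balanced))
```

In this sketch the split happens before the augmentation, so that an augmented copy of a test image can never end up in the training set. That ordering is my own design choice for the example.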
Model construction: defining loss function and limiting overfitting
The first step towards the construction of the CNN is a proper definition of the loss function. The loss function compares the predictions of the CNN model with the ground truth. The closer the model predictions are to the ground truth, the lower the loss. Therefore, we can optimize and train the model by minimizing the loss function. The optimal type of loss function depends on the machine learning task. For this use case, the cross-entropy loss was used, as is typical for classification tasks.
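To make the idea of "lower loss for better predictions" concrete, here is a small sketch of the binary cross-entropy loss in plain Python. The function name and the example probabilities are made up for illustration; the article does not say which exact variant of cross-entropy the developers used.

```python
import math

def binary_cross_entropy(y_true, y_pred, eps=1e-12):
    """Mean of -[t*log(p) + (1-t)*log(1-p)] over all samples.

    y_true holds the labels (0 = healthy, 1 = pneumonia),
    y_pred holds the predicted probabilities of pneumonia.
    """
    total = 0.0
    for t, p in zip(y_true, y_pred):
        p = min(max(p, eps), 1.0 - eps)   # clip to avoid log(0)
        total += -(t * math.log(p) + (1 - t) * math.log(1.0 - p))
    return total / len(y_true)

labels = [0, 1, 1, 0]
mostly_right = [0.1, 0.9, 0.8, 0.2]   # predictions close to the labels
mostly_wrong = [0.9, 0.1, 0.2, 0.8]   # predictions far from the labels

loss_good = binary_cross_entropy(labels, mostly_right)
loss_bad = binary_cross_entropy(labels, mostly_wrong)
print(f"loss for good predictions: {loss_good:.3f}")
print(f"loss for bad predictions:  {loss_bad:.3f}")
```

The predictions that sit close to the labels yield the smaller loss, which is exactly the property that makes the loss usable as a training objective.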
The convolutional neural network was constructed sequentially out of five sets of convolution and max-pooling layers. The first set of layers in the CNN extracts weak local features that help to differentiate chest X-ray images of healthy people from those of unhealthy people by looking at minute details. A typical example of local feature extraction is “edge detection”: it identifies high-contrast areas in the image. With each subsequent set of layers, those local features are combined to construct strong global features. For example, a global feature might be the presence or absence of unusual cloudy formations in the lungs. Those global features are finally used as input for a conventional fully connected feedforward neural network (FNN) that classifies the images into the corresponding classes.
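The following sketch shows what a single convolution plus max-pooling step does, using edge detection as the example. It is a toy illustration under my own assumptions: the 6x6 image and the hand-written vertical-edge kernel are invented for the demo, whereas a real CNN learns its kernels during training.

```python
import numpy as np

def conv2d(image, kernel):
    """'Valid' 2D cross-correlation (no padding), the operation used in CNN convolution layers."""
    kh, kw = kernel.shape
    out_h = image.shape[0] - kh + 1
    out_w = image.shape[1] - kw + 1
    out = np.zeros((out_h, out_w))
    for i in range(out_h):
        for j in range(out_w):
            out[i, j] = np.sum(image[i:i + kh, j:j + kw] * kernel)
    return out

def max_pool2d(feature_map, size=2):
    """Keep the maximum of each non-overlapping size x size block."""
    h = feature_map.shape[0] // size * size
    w = feature_map.shape[1] // size * size
    trimmed = feature_map[:h, :w]
    return trimmed.reshape(h // size, size, w // size, size).max(axis=(1, 3))

# Toy image: dark left half, bright right half, i.e. one vertical edge.
image = np.zeros((6, 6))
image[:, 3:] = 1.0

# Hand-written kernel that responds to dark-to-bright transitions from left to right.
vertical_edge = np.array([[-1.0, 0.0, 1.0],
                          [-1.0, 0.0, 1.0],
                          [-1.0, 0.0, 1.0]])

edges = conv2d(image, vertical_edge)   # 4x4 feature map, large where the edge is
pooled = max_pool2d(edges)             # 2x2 summary of the feature map
print(edges)
print(pooled)
```

The feature map is non-zero only around the edge, and pooling shrinks it while keeping the strongest response. Stacking several such steps is what lets later layers combine local features into global ones.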
One problem that can occur with CNNs is overfitting: in this use case, it was limited by dropout layers that randomly disconnect nodes in the neural network during training. The developers used the rectified linear unit (ReLU) activation function throughout the neural network (except for the output layer), as it makes networks easier to train than the alternative activation functions and allows for deeper neural networks. It is important to understand that there is no neural network architecture that always works best; it needs to be optimized for each new project.
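Both building blocks are simple enough to write out by hand. The sketch below shows ReLU and the common "inverted" form of dropout in NumPy. The dropout rate of 0.2 and the random input are assumptions for illustration, since the article does not state the rate the developers used.

```python
import numpy as np

def relu(x):
    """Rectified linear unit: pass positive values through, set negative values to zero."""
    return np.maximum(0.0, x)

def dropout(activations, rate, rng, training=True):
    """Inverted dropout: zero a random fraction `rate` of the units and rescale the rest.

    Rescaling keeps the expected activation unchanged, so nothing has to be
    adjusted at inference time, when dropout is switched off.
    """
    if not training or rate == 0.0:
        return activations
    keep_prob = 1.0 - rate
    mask = rng.random(activations.shape) < keep_prob
    return activations * mask / keep_prob

rng = np.random.default_rng(seed=0)
pre_activations = rng.normal(size=(4, 1000))   # stand-in for the output of a layer
hidden = relu(pre_activations)

dropped = dropout(hidden, rate=0.2, rng=rng)                     # training mode
inference = dropout(hidden, rate=0.2, rng=rng, training=False)   # inference mode

print("mean activation without dropout:", hidden.mean())
print("mean activation with dropout:   ", dropped.mean())
```

Because a different random subset of nodes is switched off in every training step, the network cannot rely on any single node, which is what limits overfitting.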
Model training with mini batches
The training process of such deep CNNs is typically time- and compute-intensive and requires a lot of high-quality data. However, once the model is trained, it can quickly classify X-ray images as “healthy” or “unhealthy”.
To reduce the network training time, mini batches of 32 images each were used. As you can see in Figure 2, the training and validation accuracies (left) and losses (right) more or less level off after 20 epochs. This means that the CNN model training was successful and converged to a stable solution. The trained model predicts whether a training image is “healthy” or “unhealthy” with 98% accuracy. Most importantly, it predicts whether a novel X-ray image is “healthy” or “unhealthy” with approximately 80% accuracy; the gap between the two numbers indicates that some overfitting remains. An RMSprop optimization algorithm was used as the network optimizer, together with an adaptive learning rate, to improve performance and avoid convergence to a suboptimal solution.
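To show how mini batches and RMSprop fit together, here is a minimal NumPy sketch that trains a logistic regression, used as a small stand-in for the CNN, on synthetic data. Only the batch size of 32 and the 20 epochs are taken from the text. The data, the learning rate, the decay factor and the helper names are assumptions, and the adaptive learning rate schedule is left out for brevity.

```python
import numpy as np

def iterate_minibatches(features, targets, batch_size, rng):
    """Yield shuffled mini batches that together cover the dataset once (one epoch)."""
    order = rng.permutation(len(features))
    for start in range(0, len(features), batch_size):
        idx = order[start:start + batch_size]
        yield features[idx], targets[idx]

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def mean_loss(weights, features, targets):
    """Mean binary cross-entropy of the model on the given data."""
    p = np.clip(sigmoid(features @ weights), 1e-12, 1 - 1e-12)
    return float(-np.mean(targets * np.log(p) + (1 - targets) * np.log(1 - p)))

rng = np.random.default_rng(seed=0)
features = rng.normal(size=(640, 5))                   # synthetic stand-in data
true_weights = np.array([1.5, -2.0, 0.5, 0.0, 1.0])    # hypothetical ground truth
targets = (features @ true_weights > 0).astype(float)

weights = np.zeros(5)
avg_sq_grad = np.zeros(5)                      # running average of squared gradients
learning_rate, decay, eps = 0.01, 0.9, 1e-8    # illustrative hyperparameters

loss_before = mean_loss(weights, features, targets)
for epoch in range(20):
    for xb, yb in iterate_minibatches(features, targets, 32, rng):
        grad = xb.T @ (sigmoid(xb @ weights) - yb) / len(xb)
        # RMSprop: scale each weight's step by the recent magnitude of its gradient.
        avg_sq_grad = decay * avg_sq_grad + (1 - decay) * grad ** 2
        weights -= learning_rate * grad / (np.sqrt(avg_sq_grad) + eps)
loss_after = mean_loss(weights, features, targets)

print(f"loss before training: {loss_before:.3f}")
print(f"loss after training:  {loss_after:.3f}")
```

Each update uses only 32 samples instead of the full dataset, which is what makes an individual step cheap; the per-weight scaling in RMSprop keeps the step sizes comparable even when gradient magnitudes differ a lot between weights.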
Fig. 2: Model accuracy (left) and cross-entropy loss (right) during training.
Model evaluation: the importance of accuracy and recall
To evaluate and compare classification models, it is important to discuss model accuracy and model recall. Model accuracy is used here in the sense of how reliable a pneumonia prediction is, that is, what share of the cases labeled as pneumonia really are pneumonia (strictly speaking, this quantity is called precision), whereas model recall shows what share of the actual pneumonia cases the model identifies. If the model is very “cautious” and identifies pneumonia only in highly obvious cases, it is highly accurate in this sense, but it has low recall. On the other hand, if the model is “generous” and identifies pneumonia even on slight suspicion, it has excellent recall, but low accuracy.
The developers of our use case model calculated accuracy and recall by comparing the true test labels with the labels predicted from the test X-ray images. The results are summarized in a confusion matrix that shows how many healthy and pneumonia cases respectively (Fig. 3 vertical axis) are labeled as such by the CNN (Fig. 3 horizontal axis). Depending on the nature of the use case, the relative importance of accuracy and recall must be specified before designing the CNN. In the case of a medical tool, recall is very important because labeling a person with pneumonia as healthy can have severe consequences.
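The trade-off between a “cautious” and a “generous” model can be reproduced with a few lines of plain Python. The labels and predictions below are invented for illustration and are not the results from Fig. 3. Note that the sketch reports the share of correct pneumonia predictions under its usual name, precision, next to the overall accuracy.

```python
def confusion_counts(y_true, y_pred, positive=1):
    """Count true positives, false positives, false negatives and true negatives."""
    tp = fp = fn = tn = 0
    for t, p in zip(y_true, y_pred):
        if p == positive and t == positive:
            tp += 1
        elif p == positive:
            fp += 1
        elif t == positive:
            fn += 1
        else:
            tn += 1
    return tp, fp, fn, tn

def summarize(y_true, y_pred):
    """Return overall accuracy, precision and recall for the pneumonia class (label 1)."""
    tp, fp, fn, tn = confusion_counts(y_true, y_pred)
    return {
        "accuracy": (tp + tn) / (tp + fp + fn + tn),
        "precision": tp / (tp + fp) if tp + fp else 0.0,
        "recall": tp / (tp + fn) if tp + fn else 0.0,
    }

# 1 = pneumonia, 0 = healthy; six pneumonia cases and four healthy cases.
y_true = [1, 1, 1, 1, 1, 1, 0, 0, 0, 0]
cautious = [1, 1, 1, 0, 0, 0, 0, 0, 0, 0]   # flags only the obvious cases
generous = [1, 1, 1, 1, 1, 1, 1, 1, 0, 0]   # flags on slight suspicion

cautious_scores = summarize(y_true, cautious)
generous_scores = summarize(y_true, generous)
print("cautious:", cautious_scores)
print("generous:", generous_scores)
```

The cautious model never raises a false alarm but misses half of the pneumonia cases, while the generous model finds all of them at the cost of two false alarms. For a medical tool, the second profile is usually the safer one.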
Fig. 3: Confusion matrix
When building an Applied Medical AI, for example one that detects pneumonia on the basis of X-ray images, you should pay attention to the following points:
- Make sure that the classes used for the classification are about equal in size, so that they do not bias the model to predict a dominant class.
- In terms of the loss function, since the primary goal is to construct a classification model, use the cross-entropy loss, which is minimized during model training.
- Limit overfitting in the CNN with the help of dropout layers that randomly disconnect nodes in the neural network.
- To evaluate your model, look at its accuracy and recall. For medical tools, a high recall is very important since a case that is falsely identified as “healthy” can have severe consequences for a patient.
As I explained previously, an increasing number of healthcare institutions implement AI-driven solutions to assist them, for example with early disease detection. The COVID-19 pandemic has further accelerated the adoption of medical AI technology. AI-driven solutions for the identification of pneumonia, like the solution in our use case, are already used in hospitals including UC San Diego.