
Microbe Classification Using Deep Learning


Part 2: Training

First, we import the necessary modules.

from fastai.vision import *  # fastai v1 vision API: ImageDataBunch, cnn_learner, models, ...

# Parent folder containing one sub-folder of images per class
path = r"C:bacteriadata"

My folder structure: parent folder (C:bacteriadata) → sub-folders named after the classes (e.g. C:bacteriadatacandida) → images
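This is the usual "one folder per class" layout: each image is labelled with the name of the folder it sits in. A minimal standard-library sketch of that labelling step, useful for checking class counts before training. The function names and the extension set are hypothetical choices for illustration, not part of fastai, and unlike fastai this sketch only looks one level deep.

```python
from pathlib import Path

# Hypothetical set of accepted image extensions; adjust for your dataset.
IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".tif", ".tiff", ".bmp"}

def index_image_folder(root):
    """Return sorted (image_path, label) pairs, labelling each image by its parent folder."""
    root = Path(root)
    items = []
    for class_dir in sorted(p for p in root.iterdir() if p.is_dir()):
        for img in sorted(class_dir.iterdir()):
            # Keep only files whose extension looks like an image
            if img.is_file() and img.suffix.lower() in IMAGE_EXTS:
                items.append((img, class_dir.name))
    return items

def class_counts(items):
    """Count how many images each class folder contributes."""
    counts = {}
    for _, label in items:
        counts[label] = counts.get(label, 0) + 1
    return counts
```

For example, `class_counts(index_image_folder(path))` gives the number of images per class, a quick way to spot class imbalance.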

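np.random.seed(42)  # make the random 80/20 train/validation split reproducible
# Label images by sub-folder name, resize to 224 px, apply the default augmentations,
# and normalize with ImageNet statistics to match the pretrained ResNet-50
data = ImageDataBunch.from_folder(path, valid_pct=0.2,
    ds_tfms=get_transforms(), size=224, num_workers=4, bs=32).normalize(imagenet_stats)
data.classes, data.c, len(data.train_ds), len(data.valid_ds)  # class names, number of classes, split sizes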

We have 12 different classes with 10,234 training images and 2,558 validation images.
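The seed matters because, as far as I can tell from the fastai v1 source, `valid_pct` is implemented by shuffling the item indices with NumPy's global random generator and sending the first `int(valid_pct * n)` of them to the validation set. That agrees with the counts above: 20% of the 12,792 images is 2,558.4, which truncates to 2,558. A standard-library sketch of the same idea (`random_split` is a hypothetical name; Python's `random` module will not reproduce NumPy's shuffle, so the exact split differs from fastai's even with the same seed):

```python
import random

def random_split(items, valid_pct=0.2, seed=42):
    """Shuffle indices and put the first int(valid_pct * n) items in the validation set."""
    idx = list(range(len(items)))
    random.Random(seed).shuffle(idx)  # seeded, so the split is reproducible
    cut = int(valid_pct * len(items))
    valid = [items[i] for i in idx[:cut]]
    train = [items[i] for i in idx[cut:]]
    return train, valid

train, valid = random_split(list(range(12_792)), valid_pct=0.2)
print(len(train), len(valid))  # 10234 2558
```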

learn = cnn_learner(data, models.resnet50, metrics=accuracy).to_fp16()  # pretrained ResNet-50 with a new head, mixed precision

fast.ai supports mixed-precision training, and enabling it is as simple as adding .to_fp16() when building the learner. On NVIDIA RTX graphics cards, whose Tensor Cores accelerate half-precision arithmetic, mixed precision can speed up training considerably and roughly halve memory use. Eric has an excellent article on this, which you can read here: https://towardsdatascience.com/rtx-2060-vs-gtx-1080ti-in-deep-learning-gpu-benchmarks-cheapest-rtx-vs-most-expensive-gtx-card-cd47cd9931d2
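Two numerical facts explain both the benefit and the catch. A float16 value takes 2 bytes instead of 4, which is where the memory saving comes from. But float16 cannot represent very small numbers, so tiny gradients underflow to zero. Mixed-precision training, as I understand fastai's implementation, therefore keeps a float32 master copy of the weights and multiplies the loss by a scale factor before back-propagation, dividing the gradients back down in float32, with the scale adjusted dynamically. The NumPy sketch below only illustrates those numerics; the tensor shape and the scale of 2**16 are arbitrary illustrative values, and `update_scale` is a hypothetical, simplified helper.

```python
import numpy as np

# 1) Storage: float16 needs half the bytes of float32.
acts32 = np.zeros((32, 64, 56, 56), dtype=np.float32)  # illustrative activation tensor
acts16 = acts32.astype(np.float16)
print(acts32.nbytes, acts16.nbytes)  # 25690112 12845056

# 2) Underflow: a tiny gradient becomes zero when cast to float16.
grad = np.float32(1e-8)
print(np.float16(grad))  # 0.0

# 3) Loss scaling: scale up before the float16 cast, unscale in float32 afterwards.
scale = np.float32(2.0 ** 16)            # illustrative static scale factor
scaled_fp16 = np.float16(grad * scale)   # about 6.55e-4, representable in float16
recovered = np.float32(scaled_fp16) / scale
print(recovered)                         # close to 1e-8, up to float16 rounding

def update_scale(scale, grads_fp16, factor=2.0):
    """Simplified dynamic loss scaling: on overflow, shrink the scale and skip the step.

    Real implementations also grow the scale again after a run of overflow-free steps.
    """
    if not np.all(np.isfinite(grads_fp16)):
        return scale / factor, False  # False: skip this optimizer step
    return scale, True
```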

learn.fit_one_cycle(4)  # 4 epochs; only the new head trains while the pretrained body stays frozen
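`fit_one_cycle` trains under the one-cycle policy: the learning rate warms up from a low value to `max_lr`, then anneals to a value far below where it started. A pure-Python sketch of such a schedule. The defaults are modeled on my reading of fastai v1 (`max_lr=3e-3`, `pct_start=0.3`, `div_factor=25`, final rate `max_lr / (div_factor * 1e4)`, cosine annealing in both phases) and should be treated as assumptions to check against the library source; fastai also cycles momentum in the opposite direction, which is omitted here.

```python
import math

def cos_anneal(start, end, pct):
    """Cosine interpolation from start (pct=0) to end (pct=1)."""
    return end + (start - end) / 2 * (math.cos(math.pi * pct) + 1)

def one_cycle_lr(step, total_steps, max_lr, pct_start=0.3, div_factor=25.0, final_div=None):
    """Learning rate at iteration `step` (0..total_steps) of a one-cycle schedule."""
    if final_div is None:
        final_div = div_factor * 1e4
    warmup = max(1, int(total_steps * pct_start))
    if step < warmup:
        # Phase 1: warm up from max_lr / div_factor to max_lr
        return cos_anneal(max_lr / div_factor, max_lr, step / warmup)
    # Phase 2: anneal from max_lr down to max_lr / final_div
    pct = (step - warmup) / max(1, total_steps - warmup)
    return cos_anneal(max_lr, max_lr / final_div, pct)

# About 4 epochs of 10,234 images at batch size 32 is roughly 1,280 iterations.
total = 4 * 320
lrs = [one_cycle_lr(s, total, max_lr=3e-3) for s in range(total + 1)]
```

Starting low gives the freshly initialized head a gentle warm-up, and the long annealing tail lets the weights settle at the end of training.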

The results are already looking pretty good, even though only the newly added head has been trained so far: cnn_learner keeps the pretrained ResNet-50 layers frozen by default. We will now unfreeze the whole model and find a suitable learning rate.

learn.unfreeze()       # make all layer groups trainable
learn.lr_find()        # learning-rate range test
learn.recorder.plot()  # plot loss against learning rate
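`lr_find()` runs a learning-rate range test: it trains briefly with a learning rate that grows exponentially each mini-batch, records the loss, and stops early once the loss blows up, after which `recorder.plot()` shows loss against learning rate on a log axis. A toy version on a one-parameter quadratic loss; the loss function, the 4x divergence threshold, and the steepest-drop heuristic are illustrative choices, and the start/end rates and iteration count are modeled on what I understand fastai v1's defaults to be. Real loss curves are noisy, so libraries smooth them before plotting.

```python
import math

def lr_range_test(start_lr=1e-7, end_lr=10.0, num_it=100, w0=0.0, target=3.0):
    """Toy LR range test on loss(w) = (w - target)**2 using plain gradient descent."""
    w, best = w0, math.inf
    lrs, losses = [], []
    for i in range(num_it):
        lr = start_lr * (end_lr / start_lr) ** (i / (num_it - 1))  # exponential growth
        loss = (w - target) ** 2
        lrs.append(lr)
        losses.append(loss)
        best = min(best, loss)
        if loss > 4 * best:  # stop once the loss has clearly diverged
            break
        w -= lr * 2 * (w - target)  # gradient of (w - target)**2 is 2 * (w - target)
    return lrs, losses

def steepest_drop(lrs, losses):
    """Suggest the learning rate at which the loss fell fastest between two steps."""
    drops = [losses[i + 1] - losses[i] for i in range(len(losses) - 1)]
    i = min(range(len(drops)), key=drops.__getitem__)
    return lrs[i]

lrs, losses = lr_range_test()
print(len(lrs), steepest_drop(lrs, losses))
```

On this quadratic, gradient descent diverges once the rate exceeds 1, so the test stops before reaching `end_lr`. For a real network, a common rule of thumb is to pick a rate where the loss is still falling steeply, well below the point where it bottoms out.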

I then trained the entire model for 10 more epochs with reduced learning rates, using smaller rates for the early layers than for the head.

learn.fit_one_cycle(10, max_lr=slice(7e-5, 9e-4))  # from 7e-5 for the earliest layers up to 9e-4 for the head
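Passing a `slice` as `max_lr` applies discriminative learning rates. In fastai v1 the model is split into layer groups, and, as I understand the library's `even_mults` helper, the rates are spaced geometrically from the first value (earliest, most generic layers) to the second (the head). Assuming the three layer groups fastai v1 typically creates for a ResNet, the per-group peaks work out as follows; this `even_mults` is a re-implementation for illustration, not the library's code.

```python
def even_mults(start, stop, n):
    """n values spaced geometrically from start to stop, both inclusive."""
    if n == 1:
        return [stop]  # single group: just use the top rate (a choice made for this sketch)
    step = (stop / start) ** (1 / (n - 1))
    return [start * step ** i for i in range(n)]

group_lrs = even_mults(7e-5, 9e-4, 3)  # assumed 3 groups: early body, late body, head
print([f"{lr:.2e}" for lr in group_lrs])  # ['7.00e-05', '2.51e-04', '9.00e-04']
```

Each layer group then follows its own one-cycle schedule peaking at its own rate, so the pretrained early layers, which already detect generic edges and textures, change less than the task-specific head.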

At the end of 10 epochs, we achieved a validation accuracy of 99.6%. There is no significant overfitting, as the training and validation losses are similar. Let’s take a look at the confusion matrix.

interp = ClassificationInterpretation.from_learner(learn)
interp.plot_confusion_matrix(figsize=(10,10), dpi=100)
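A confusion matrix counts, for every (actual, predicted) pair of classes, how many validation images fall into that cell; everything off the diagonal is an error. fastai v1's `ClassificationInterpretation` also has a `most_confused()` method that lists the largest off-diagonal cells. A plain-Python sketch of both ideas, with toy labels invented purely for illustration (they are not the article's results):

```python
from collections import Counter

def confusion_matrix(actual, predicted, classes):
    """Rows are actual classes, columns are predicted classes."""
    index = {c: i for i, c in enumerate(classes)}
    cm = [[0] * len(classes) for _ in classes]
    for a, p in zip(actual, predicted):
        cm[index[a]][index[p]] += 1
    return cm

def most_confused(actual, predicted, min_val=1):
    """(actual, predicted, count) for misclassified pairs, most frequent first."""
    errors = Counter((a, p) for a, p in zip(actual, predicted) if a != p)
    return [(a, p, n) for (a, p), n in errors.most_common() if n >= min_val]

# Toy labels for illustration only.
classes = ["e_coli", "proteus", "pseudomonas"]
actual = ["proteus", "proteus", "e_coli", "pseudomonas", "proteus", "e_coli"]
predicted = ["pseudomonas", "proteus", "proteus", "pseudomonas", "pseudomonas", "e_coli"]
print(confusion_matrix(actual, predicted, classes))  # [[1, 1, 0], [0, 1, 2], [0, 0, 1]]
print(most_confused(actual, predicted))  # [('proteus', 'pseudomonas', 2), ('e_coli', 'proteus', 1)]
```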

The main errors seem to arise from Proteus vs. Pseudomonas and E. coli vs. Proteus. This is understandable, as all three species are Gram-negative rods, which stain and appear similar under a light microscope. Let’s take a look at the images with the top losses.

interp.plot_top_losses(6, figsize=(15,15))
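`plot_top_losses` ranks validation images by their individual loss, which for this classifier is cross-entropy: minus the log of the probability the model assigned to the true class. A confident wrong prediction therefore ranks highest. A minimal sketch of that ranking, assuming you already have per-image softmax probabilities; `top_losses` is a hypothetical helper and the probabilities are invented for illustration.

```python
import math

def top_losses(probs, targets, k):
    """Return (index, loss) for the k samples with the highest cross-entropy loss.

    probs: per-sample lists of class probabilities (strictly positive, like softmax outputs)
    targets: true class index for each sample
    """
    losses = [-math.log(p[t]) for p, t in zip(probs, targets)]
    order = sorted(range(len(losses)), key=lambda i: losses[i], reverse=True)
    return [(i, losses[i]) for i in order[:k]]

# Invented probabilities over three classes; sample 1 is a confident mistake.
probs = [[0.90, 0.05, 0.05],
         [0.02, 0.97, 0.01],
         [0.30, 0.40, 0.30]]
targets = [0, 2, 1]
print([i for i, _ in top_losses(probs, targets, k=2)])  # [1, 2]
```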
