Retrain an image classification model

This tutorial shows you how to retrain an image classification model to recognize a new set of classes. You'll use a technique called transfer learning to retrain an existing model and then compile it to run on any device with an Edge TPU, such as the Coral Dev Board or USB Accelerator.

Specifically, this tutorial shows you how to perform quantization-aware training (using TensorFlow 1.15) with the MobileNet V1 model so it can recognize different types of flowers (adapted from TensorFlow's docs).

Note that this tutorial runs the training scripts on your own computer inside a Docker container, so the training time (and even whether training can finish at all) depends on your system's specs. As an alternative, we also offer retraining tutorials that run in the cloud using Google Colab:

What is transfer learning?

Ordinarily, training an image classification model from scratch can take many hours on a CPU. Transfer learning shortens this by taking a model that is already trained for a related task and using it as the starting point for a new model. Depending on your system and training parameters, retraining this way can take less than an hour. (This process is sometimes also called "fine-tuning" the model.)

Transfer learning can be done in two ways:

  • Last-layers-only retraining: This approach retrains only the last few layers of the model, where the final classification occurs. It is fast, and it works with a small dataset.
  • Full-model retraining: This approach retrains every layer of the neural network using the new dataset. It can produce a more accurate model, but it takes more time, and you must retrain with a large enough dataset to avoid overfitting the model.
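For a rough sense of the difference in cost, count the parameters in the final classification (logits) layer alone. This assumes the standard 1.0-width MobileNet V1, whose final feature vector has 1024 values, and a five-class head; the tutorial's last-layers-only mode may train more than this single layer, so treat the numbers as illustrative:

```latex
\underbrace{1024 \times 5}_{\text{weights}} + \underbrace{5}_{\text{biases}} = 5125
\quad \text{parameters, versus} \quad \approx 3.2 \times 10^{6} \text{ for the whole network}
```

That is roughly 0.16% of the network's weights. Fitting about 5,000 weights is far cheaper, and far less prone to overfitting a small dataset, than fitting several million.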

Transfer learning is most effective when the features learned in the pre-trained model are general, not highly specialized. For example, a pre-trained model that can recognize household objects might be retrained to recognize new office supplies, but a model pre-trained to recognize different dog breeds might not.

The steps below show you how to perform transfer learning using either last-layers-only or full-model retraining. Most of the steps are the same; just watch for the commands that differ depending on the technique you choose.

These instructions do not require deep experience with TensorFlow or convolutional neural networks (CNNs), but such experience will help you build a more accurate model. This tutorial also does not teach you how to design and organize a dataset, or how to tune hyperparameters to get the highest possible accuracy from your model. For those topics, refer to other literature about deep learning strategies.

Total time required to complete this tutorial is about 1 hour. But if you're experienced with TensorFlow and you retrain only the last few layers, you can finish in about 30 minutes.

Requirements

You need the following for this tutorial:

  • Any computer supported by Docker (such as Linux, Mac, or Windows).
  • At least 4 GB of RAM.
  • A device with an Edge TPU, such as the Coral Dev Board or USB Accelerator (these each have their own list of requirements).

Note: This tutorial is designed to run training on a desktop CPU, not on a GPU or in the cloud; those setups require changes beyond the scope of this tutorial. Also, don't try this on the Coral Dev Board: its CPU and memory are too constrained, and this training cannot be accelerated by the Edge TPU.

Set up the Docker container

Docker is a container platform that makes it easy to set up an isolated environment for this tutorial. Our Docker container includes the required environment: TensorFlow, Python, the classification scripts, and the pre-trained checkpoints for MobileNet V1 and V2.

To set up your container, follow these steps:

  1. First, install Docker on your desktop machine (Docker's documentation provides installation instructions for each supported platform).

  2. Open a command line and create a directory for all the files in this project. You will clone the Coral tutorials repo into it, so name it accordingly. For example:

    CORAL_DIR=${HOME}/google-coral && mkdir -p ${CORAL_DIR}
    
  3. Move into that directory and clone our tutorials repo, which has all the training scripts:

    cd ${CORAL_DIR}
    
    git clone https://github.com/google-coral/tutorials.git
  4. Move into the directory for this tutorial and build the Docker image:

    cd tutorials/docker/classification
    
    docker build . -t classify-tutorial-tf1
  5. Specify the location for the training output files. For example:

    CLASSIFY_DIR=${PWD}/out && mkdir -p $CLASSIFY_DIR
    

    You'll use this as the mount location for a directory in the Docker container, thus saving the training files and final model to your file system (instead of leaving them inside the Docker container).

  6. Start the Docker container:

    docker run --name edgetpu-classify \
    --rm -it --privileged -p 6006:6006 \
    --mount type=bind,src=${CLASSIFY_DIR},dst=/tensorflow/models/research/slim/transfer_learn \
    classify-tutorial-tf1
    

When that's finished, your command prompt should be inside the Docker container and in the path /tensorflow/models/research/slim/.
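The --mount flag in step 6 makes the container's transfer_learn/ directory and your host's $CLASSIFY_DIR the same directory, so anything the scripts write there survives after the container exits. As a minimal sketch of that mapping (to_host_path is a hypothetical helper, not part of the tutorial's scripts), a container path translates to a host path like this:

```shell
#!/bin/sh
# Hypothetical helper, not part of the tutorial repo.
# Container side of the bind mount (the dst= value in the docker run command).
CONTAINER_MOUNT=/tensorflow/models/research/slim/transfer_learn

# Host side of the mount (the src= value); defaults to the example location from step 5.
CLASSIFY_DIR=${CLASSIFY_DIR:-${HOME}/google-coral/tutorials/docker/classification/out}

# Print the host path for a container path under CONTAINER_MOUNT.
to_host_path() {
  case "$1" in
    "$CONTAINER_MOUNT"/*)
      # Strip the container prefix and prepend the host directory.
      printf '%s\n' "${CLASSIFY_DIR}${1#"$CONTAINER_MOUNT"}" ;;
    *)
      echo "not under ${CONTAINER_MOUNT}: $1" >&2
      return 1 ;;
  esac
}

to_host_path /tensorflow/models/research/slim/transfer_learn/models/output_tflite_graph.tflite
```

Paths outside the mounted directory have no host counterpart, which is why the helper rejects them: files written elsewhere in the container are lost when it exits (the docker run command uses --rm).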

You're ready to start training your model.

Prepare your dataset

In this tutorial, you'll create a flower classifier. So before you begin training, you need to download the flowers dataset and convert it to the TFRecord format. We've prepared the following script (in the slim/ directory) to take care of that for you:

# Run this from within the Docker container (in /tensorflow/models/research/slim/):
./prepare_checkpoint_and_dataset.sh --network_type mobilenet_v1

The network_type can be one of the following: mobilenet_v1, mobilenet_v2, inception_v1, inception_v2, inception_v3, or inception_v4. If you try one of the other architectures, pass the same network_type value to each of the commands below that accept it.
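Because the same network_type value must appear in several commands, one option is to set it once in a shell variable and check it against the list above. A minimal sketch; check_network_type and NETWORK_TYPE are hypothetical names, not part of the tutorial's scripts:

```shell
#!/bin/sh
# Hypothetical helper: succeed only for the network_type values listed above.
check_network_type() {
  case "$1" in
    mobilenet_v1|mobilenet_v2|inception_v1|inception_v2|inception_v3|inception_v4)
      return 0 ;;
    *)
      echo "unsupported network_type: $1" >&2
      return 1 ;;
  esac
}

# Set the architecture once, then reuse "$NETWORK_TYPE" in every later command.
NETWORK_TYPE=mobilenet_v1
check_network_type "$NETWORK_TYPE" && echo "using $NETWORK_TYPE"
```

You could then run ./prepare_checkpoint_and_dataset.sh --network_type "$NETWORK_TYPE" and pass the same variable to the training, evaluation, and conversion scripts. Keep in mind that a shell variable exists only in the terminal where you set it, so set it again in any new terminal you open with docker exec.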

Retrain your classification model

You can perform transfer learning to retrain just the last few layers of a model, or you can retrain the whole model. However, if you have limited training data, retraining the whole model can lead to overfitting, so you should retrain just the last layers instead. We'll show both methods below.

  1. Start transfer learning in one of the following ways:

    • If you want to retrain only the last few layers of the model, use the following command:

      ./start_training.sh --network_type mobilenet_v1
      
    • If you want to retrain the whole model, use this command:

      ./start_training.sh --network_type mobilenet_v1 --train_whole_model true
      

    It might take 1 to 2 minutes for the training pipeline to start. Once training begins, the terminal continuously prints training progress, with lines like this:

    INFO:tensorflow:Recording summary at step 42.
    INFO:tensorflow:global step 60: loss = 1.1883 (1.347 sec/step)
    INFO:tensorflow:global step 80: loss = 0.8204 (1.363 sec/step)
    

    Depending on your machine and the model architecture (MobileNet generally trains much faster than Inception), training the last few layers of MobileNet V1 for 300 steps can take 10 to 30 minutes (measured on a machine with a 16-core CPU and 60 GB of RAM). Training the whole model takes longer.

  2. To monitor training progress, start TensorBoard in a new terminal:

    1. Start bash in a separate terminal to join the same Docker container.

      sudo docker exec -it edgetpu-classify /bin/bash
      
    2. In the new Docker terminal, run the following command to start TensorBoard. TensorBoard then displays the model accuracy throughout training, which you can view in your local machine's browser at http://localhost:6006/.

      # From the Docker /tensorflow/models/research/slim/ directory
      tensorboard --logdir=./transfer_learn/train/
      
  3. To evaluate the model's performance using the latest checkpoint, use the run_evaluation.sh script.

    If training is still in progress, you can run the script anyway, but you need to open another terminal:

    sudo docker exec -it edgetpu-classify /bin/bash
    

    (Or just wait until training completes.) Then run the evaluation script:

    # From the Docker /tensorflow/models/research/slim/ directory
    ./run_evaluation.sh --network_type mobilenet_v1
    

    After some other output, you'll see the accuracy printed like this:

    eval/Accuracy[0.7175]eval/Recall_5[1]
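While training runs (step 1), you can turn one of the progress lines into a rough estimate of the time remaining. A minimal sketch, assuming the log format shown in step 1 and a total step count that you supply yourself (300 matches the checkpoint number used later in this tutorial, but check start_training.sh for the value it actually uses); eta_minutes is a hypothetical helper:

```shell
#!/bin/sh
# Hypothetical helper: estimate minutes remaining from one training log line.
#   $1: a line such as "INFO:tensorflow:global step 80: loss = 0.8204 (1.363 sec/step)"
#   $2: the total number of training steps
eta_minutes() {
  printf '%s\n' "$1" | awk -v total="$2" '{
    for (i = 1; i < NF; i++) {
      # The field after "step" is the step number, e.g. "80:" (awk ignores the trailing colon).
      if ($i == "step") step = $(i + 1) + 0
      # The field before "sec/step)" is the rate, e.g. "(1.363"; drop the parenthesis.
      if ($(i + 1) == "sec/step)") { rate = $i; sub(/^\(/, "", rate) }
    }
    printf "%.1f\n", (total - step) * rate / 60
  }'
}

eta_minutes 'INFO:tensorflow:global step 80: loss = 0.8204 (1.363 sec/step)' 300   # prints 5.0
```

The estimate assumes the current sec/step rate holds for the rest of the run, so recompute it from a later line if the rate changes.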
    

If the accuracy isn't high enough, open the start_training.sh file, tweak some of the parameters it passes to the train_image_classifier.py script, and then retrain (step 1). (First remove the ./transfer_learn/train/ directory, which contains the previously trained files.)
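To make that decision in a script, one option is to pull the accuracy value out of the evaluation output and compare it with a target you choose. A minimal sketch; extract_accuracy and meets_target are hypothetical helpers, and the 0.80 target is only an example:

```shell
#!/bin/sh
# Hypothetical helpers, not part of the tutorial repo.

# Read evaluation output on stdin and print the number inside eval/Accuracy[...].
extract_accuracy() {
  sed -n 's|.*eval/Accuracy[[]\([0-9.]*\)[]].*|\1|p'
}

# Exit with status 0 if accuracy $1 is at least the target $2.
meets_target() {
  awk -v acc="$1" -v target="$2" 'BEGIN { exit !(acc + 0 >= target + 0) }'
}

acc=$(echo 'eval/Accuracy[0.7175]eval/Recall_5[1]' | extract_accuracy)
if meets_target "$acc" 0.80; then
  echo "accuracy $acc meets the target"
else
  echo "accuracy $acc is below the target; tweak start_training.sh and retrain"
fi
```

With a real evaluation run you might pipe the output in, for example ./run_evaluation.sh --network_type mobilenet_v1 2>&1 | extract_accuracy; the 2>&1 captures the line in case it is written to standard error rather than standard output.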

Compile the model for the Edge TPU

To run your retrained model on the Edge TPU, you need to convert the new checkpoint file to a frozen graph, convert that to a TensorFlow Lite flatbuffer file, then compile the model for the Edge TPU. We've provided a script to simplify some of this for you, which you can run as follows.

  1. To freeze the graph and convert it to TensorFlow Lite, use the following script and specify the checkpoint number you want to use (this example uses checkpoint 300):

    # From the Docker /tensorflow/models/research/slim/ directory
    ./convert_checkpoint_to_edgetpu_tflite.sh --network_type mobilenet_v1 --checkpoint_num 300
    

    Your converted TensorFlow Lite model is named output_tflite_graph.tflite and is saved in the Docker container at /tensorflow/models/research/slim/transfer_learn/models/. Because transfer_learn/ is the mounted directory, the same file is available on your host filesystem at ${CLASSIFY_DIR}/models/.

  2. Open a new terminal (outside the Docker container) and install the Edge TPU Compiler (these apt commands assume a Debian-based Linux system such as Ubuntu):

    curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo apt-key add -
    
    echo "deb https://packages.cloud.google.com/apt coral-edgetpu-stable main" | sudo tee /etc/apt/sources.list.d/coral-edgetpu.list
    sudo apt update
    sudo apt-get install edgetpu-compiler
  3. Files created from inside the Docker container are typically owned by root, so make sure your user has ownership of the out directory:

    sudo chown -R $USER ${HOME}/google-coral/tutorials/docker/classification/out
    
  4. Now change directories to where the trained model is and compile it:

    cd ${HOME}/google-coral/tutorials/docker/classification/out/models
    
    edgetpu_compiler output_tflite_graph.tflite

    The compiled file is named output_tflite_graph_edgetpu.tflite and saved in the current directory.

  5. Finally, rename the compiled model to something more specific:

    mv output_tflite_graph_edgetpu.tflite mobilenet_v1_flowers_quant_edgetpu.tflite
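If you're not sure which checkpoint number to pass to convert_checkpoint_to_edgetpu_tflite.sh in step 1, you can look up the newest one. A minimal sketch, assuming the checkpoints are written to ./transfer_learn/train/ with TensorFlow's usual model.ckpt-<step>.index naming (list the directory to confirm); latest_checkpoint_num is a hypothetical helper:

```shell
#!/bin/sh
# Hypothetical helper: print the highest step number among the
# model.ckpt-<step>.index files in directory $1 (prints nothing if there are none).
latest_checkpoint_num() {
  ls "$1" 2>/dev/null \
    | sed -n 's/^model\.ckpt-\([0-9][0-9]*\)\.index$/\1/p' \
    | sort -n \
    | tail -n 1
}

# Demo on a throwaway directory standing in for ./transfer_learn/train/.
demo_dir=$(mktemp -d)
touch "$demo_dir/model.ckpt-0.index" "$demo_dir/model.ckpt-150.index" \
      "$demo_dir/model.ckpt-300.index" "$demo_dir/model.ckpt-300.meta"
latest_checkpoint_num "$demo_dir"   # prints 300
rm -rf "$demo_dir"
```

Inside the container (from /tensorflow/models/research/slim/), you could then pass --checkpoint_num "$(latest_checkpoint_num ./transfer_learn/train)" to the conversion script. Sorting numerically matters: a plain text sort would rank 300 above 1000.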
    

Run the model

You can now use the retrained model to run inference on the Edge TPU. Below, you can see how to use this model with the classify_image.py example, which performs image classification using the TensorFlow Lite Python API.

Remember that you've trained this model to recognize just five flower classes: daisy, dandelion, roses, sunflowers, and tulips. Here are a couple of images you can try (provided by the Open Images Dataset); download them into the out/models directory, alongside your compiled model and labels.txt:

wget https://c4.staticflickr.com/3/2856/13169252123_e4c5086ea3_z.jpg -O flower.jpg && \
wget https://c6.staticflickr.com/4/3372/3416475881_726f0d33fe_z.jpg -O flower2.jpg
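Before copying files to a device, you can sanity-check that labels.txt lists the five classes the model was trained on. A minimal sketch, assuming labels.txt holds one label per line; count_labels is a hypothetical helper:

```shell
#!/bin/sh
# Hypothetical helper: count the non-blank lines in a labels file.
count_labels() {
  grep -c '[^[:space:]]' "$1"
}

# Demo with a temporary file containing the five flower classes.
labels=$(mktemp)
printf 'daisy\ndandelion\nroses\nsunflowers\ntulips\n\n' > "$labels"
n=$(count_labels "$labels")
if [ "$n" -eq 5 ]; then
  echo "labels file has the expected 5 classes"
else
  echo "expected 5 labels, found $n" >&2
fi
rm -f "$labels"
```

On your host, you would run count_labels labels.txt from the out/models directory.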

Using the Coral Dev Board

  1. First, be sure your Dev Board software is up to date.

  2. Use MDT to push the files to the Dev Board and switch to the Dev Board shell:

    mdt push mobilenet_v1_flowers_quant_edgetpu.tflite labels.txt flower.jpg
    
    mdt shell
  3. Now from the Dev Board shell, download the classify_image.py code from GitHub:

    mkdir google-coral && cd google-coral
    git clone https://github.com/google-coral/tflite --depth 1
  4. Install the example's requirements:

    cd tflite/python/examples/classification
    
    ./install_requirements.sh
  5. Run the example using the files you pushed in step 2:

    python3 classify_image.py \
      --model ${HOME}/mobilenet_v1_flowers_quant_edgetpu.tflite \
      --labels ${HOME}/labels.txt \
      --input ${HOME}/flower.jpg
    

Using the Coral USB Accelerator

  1. First, be sure your USB Accelerator is set up.

  2. Although this is also part of the device setup, here's how to get the classify_image.py code from GitHub:

    mkdir -p ${HOME}/google-coral && cd ${HOME}/google-coral
    
    git clone https://github.com/google-coral/tflite --depth 1
  3. Install the project requirements:

    cd tflite/python/examples/classification
    
    ./install_requirements.sh
  4. Run the example using the retrained model:

    python3 classify_image.py \
      --model ${HOME}/google-coral/tutorials/docker/classification/out/models/mobilenet_v1_flowers_quant_edgetpu.tflite \
      --labels ${HOME}/google-coral/tutorials/docker/classification/out/models/labels.txt \
      --input ${HOME}/google-coral/tutorials/docker/classification/out/models/flower.jpg