Retrain a classification model on-device with backpropagation

If you're familiar with backpropagation, then you know it's used to train a neural network by updating the weights in every layer after you determine the model's current loss. However, you can also use backpropagation to update the weights of only the last layer, which lets you retrain your model very quickly. This is the technique our SoftmaxRegression API provides so you can accelerate transfer-learning with the Edge TPU.

Note: The SoftmaxRegression API is available in both PyCoral (Python) and Libcoral (C++), but this guide describes only the Python API.


Ordinarily, because a TensorFlow Lite model must be compiled to run on the Edge TPU, the weights inside the neural network are locked and cannot be modified by training on the device. However, if you remove the last layer from the model before compiling it (thus creating an embedding extractor model that outputs an image embedding), then you can implement the last layer on the device in a way that allows for retraining of that layer. So that's what we do to enable transfer-learning with SoftmaxRegression.

The SoftmaxRegression class is an on-device implementation of the fully-connected layer with softmax activation that performs final classification. And with its APIs, you can train the weights of the layer using stochastic gradient descent (SGD), immediately run inferences using the new weights, and save it as a new .tflite model file.
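To make the idea concrete, here is a minimal NumPy sketch of a fully-connected layer with softmax activation trained by mini-batch SGD on fixed embeddings. It is an illustration of the technique only, not the PyCoral implementation: the function names, hyperparameters, and the toy "embeddings" (two random clusters standing in for the extractor's output) are all assumptions made for this example.

```python
import numpy as np


def softmax(logits):
    """Row-wise softmax, shifted by the row maximum for numerical stability."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)


def train_last_layer(embeddings, labels, num_classes, num_iter=300,
                     learning_rate=0.1, batch_size=16, seed=0):
    """Trains one fully-connected + softmax layer with mini-batch SGD.

    Illustrative only; hypothetical helper, not part of PyCoral.
    """
    rng = np.random.default_rng(seed)
    feature_dim = embeddings.shape[1]
    weights = 0.01 * rng.standard_normal((feature_dim, num_classes))
    bias = np.zeros(num_classes)
    for _ in range(num_iter):
        # Sample a mini-batch (with replacement) from the fixed embeddings.
        idx = rng.choice(len(embeddings), size=batch_size)
        x, y = embeddings[idx], labels[idx]
        probs = softmax(x @ weights + bias)
        # Gradient of the mean cross-entropy loss with respect to the logits
        # is (probs - one_hot(y)) / batch_size.
        grad_logits = probs
        grad_logits[np.arange(batch_size), y] -= 1.0
        grad_logits /= batch_size
        # Backpropagate through this one layer only; the embeddings stay fixed.
        weights -= learning_rate * (x.T @ grad_logits)
        bias -= learning_rate * grad_logits.sum(axis=0)
    return weights, bias


def predict(embeddings, weights, bias):
    """Returns the most likely class id for each embedding."""
    return softmax(embeddings @ weights + bias).argmax(axis=1)


# Toy data: two well-separated clusters standing in for image embeddings.
rng = np.random.default_rng(1)
class0 = rng.normal(loc=-2.0, scale=0.5, size=(50, 4))
class1 = rng.normal(loc=2.0, scale=0.5, size=(50, 4))
embeddings = np.vstack([class0, class1])
labels = np.array([0] * 50 + [1] * 50)

weights, bias = train_last_layer(embeddings, labels, num_classes=2)
accuracy = (predict(embeddings, weights, bias) == labels).mean()
print(f'training accuracy: {accuracy:.2f}')
```

Because only one small weight matrix and a bias vector are updated, each SGD step is cheap, which is why retraining just this layer can run at near-realtime speed on the device's CPU.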

Of course, this strategy has both benefits and drawbacks:

Benefits:

  • Transfer-learning happens on-device, at near-realtime speed.
  • You don't need to recompile the model.

Drawbacks:

  • The fully-connected layer with softmax activation executes on the CPU, not the Edge TPU. However, this layer represents a very small portion of the overall network, so the impact on inference speed is minimal.
  • It's compatible with image classification models only; officially, only MobileNet and Inception.

Note: We offer an alternative on-device transfer-learning API called ImprintingEngine, which uses weight imprinting instead of backpropagation to update the weights of the last layer. For a comparison of these two techniques, read Transfer-learning on-device.

API summary

The SoftmaxRegression class represents only the softmax layer for a classification model. Unlike the ImprintingEngine, it does not encapsulate the entire model graph. So in order to perform training, you must run training data through the base model (the embedding extractor) and then feed the results to this softmax layer.

The basic procedure to train using backpropagation with the SoftmaxRegression API is as follows:

  1. Create an instance of Interpreter for the Edge TPU, such as with make_interpreter() and specifying your embedding extractor model.

  2. For each training image, run an inference with the Interpreter and collect the output (which is the image embedding).

  3. Create an instance of SoftmaxRegression and call train_with_sgd(), passing it all the image embeddings. This is where the new training happens.

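  4. Save the retrained model using serialize_model(), passing it the embedding extractor model. For example:

    # `model` is the trained SoftmaxRegression instance; both path variables are placeholders.
    with open(output_path, 'wb') as f:
      f.write(model.serialize_model(embedding_extractor_path))

  5. Then use the new model to run inferences with PyCoral and TensorFlow Lite.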

See the next section for a walkthrough with our example code.
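As a rough sketch of steps 1 to 4 in one place, the code below strings the calls together. The helper names (split_dataset, retrain), the 80/20 split, and the hyperparameter values are hypothetical, and the PyCoral module paths, the dictionary keys passed to train_with_sgd(), and the use of classify.get_scores() to read the embedding reflect our understanding of the PyCoral API, so check them against the API reference before relying on them. The retrain() function needs PyCoral and an attached Edge TPU; split_dataset() needs only NumPy.

```python
import numpy as np


def split_dataset(embeddings, labels, val_fraction=0.2, seed=0):
    """Shuffles embeddings with their labels and splits off a validation set.

    The dictionary keys are the ones we understand train_with_sgd() to expect.
    """
    rng = np.random.default_rng(seed)
    order = rng.permutation(len(embeddings))
    embeddings, labels = embeddings[order], labels[order]
    num_val = max(1, int(len(embeddings) * val_fraction))
    return {
        'data_train': embeddings[num_val:],
        'labels_train': labels[num_val:],
        'data_val': embeddings[:num_val],
        'labels_val': labels[:num_val],
    }


def retrain(extractor_path, images, labels, num_classes, output_path):
    """Hypothetical helper covering steps 1 to 4.

    `images` is a list of PIL images and `labels` a list of integer class ids.
    """
    # Imported here so split_dataset() stays usable without PyCoral installed.
    from pycoral.adapters import classify, common
    from pycoral.learn.backprop.softmax_regression import SoftmaxRegression
    from pycoral.utils.edgetpu import make_interpreter

    # Step 1: load the embedding extractor on the Edge TPU.
    interpreter = make_interpreter(extractor_path)
    interpreter.allocate_tensors()
    input_size = common.input_size(interpreter)

    # Step 2: one inference per training image; the output is the embedding.
    embeddings = []
    for image in images:
        common.set_input(interpreter, image.resize(input_size))
        interpreter.invoke()
        embeddings.append(classify.get_scores(interpreter))
    embeddings = np.array(embeddings, dtype=np.float32)

    # Step 3: train only the softmax layer on the collected embeddings.
    model = SoftmaxRegression(embeddings.shape[1], num_classes)
    model.train_with_sgd(split_dataset(embeddings, np.array(labels)),
                         num_iter=500, learning_rate=0.01)

    # Step 4: join the new layer to the extractor and save the result.
    with open(output_path, 'wb') as f:
        f.write(model.serialize_model(extractor_path))
```

Holding back part of the embeddings as a validation set lets the training logs report accuracy on data the layer was not fitted to.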

Retrain a model with our sample code

To better illustrate how you can use the SoftmaxRegression API, we've created a sample script. Follow the procedure below to try it with a flowers dataset.

If you're using the Dev Board, execute these commands on the board's terminal; if you're using the USB Accelerator, be sure it's connected to the host computer where you'll run these commands.

  1. Set up the directory where you'll save all your work (the commands below assume the DEMO_DIR environment variable is set to that directory's path):

    mkdir -p $DEMO_DIR

  2. Download and extract the flowers dataset:

    tar zxf flower_photos.tgz -C $DEMO_DIR

  3. Download our embedding extractor (a version of the neural network without the final fully-connected layer, pre-trained on ImageNet):

    wget -P $DEMO_DIR

    We've also created embedding extractors for all three sizes of EfficientNet. You can find the links on our Models page. For example, here's the medium size:

    wget -P $DEMO_DIR

    If you want to use your own model, see the section below about how to create your own embedding extractor.

  4. Download and navigate to the sample code:

    mkdir coral && cd coral
    git clone
    cd pycoral/examples/

  5. Start transfer learning on the Edge TPU:

    python3 \
    --data_dir ${DEMO_DIR}/flower_photos \
    --embedding_extractor_path \
      ${DEMO_DIR}/mobilenet_v1_1.0_224_quant_embedding_extractor_edgetpu.tflite \
    --output_dir ${DEMO_DIR}

    This takes 1 to 2 minutes, and you should see training logs printed to the console.

  6. Verify that the retrained model works by running it through the classification script:

    # Download a rose image from Open Images:
    curl -o ${DEMO_DIR}/rose.jpg
    python3 \
    --model ${DEMO_DIR}/retrained_model_edgetpu.tflite \
    --label ${DEMO_DIR}/label_map.txt \
    --input ${DEMO_DIR}/rose.jpg

    You should see results such as this:

    ---------------------------
    roses
    Score : 0.99609375
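The sample script reads its training images from --data_dir. As a rough sketch of that kind of preprocessing (not the script's actual code), the following assumes the dataset directory holds one subdirectory per class, as flower_photos does, and turns it into shuffled train and test lists plus a label map. The function names, the 10% test ratio, and the throwaway demo directory are all hypothetical.

```python
import os
import random
import tempfile


def list_images_by_class(data_dir):
    """Maps each class name (subdirectory name) to its sorted image paths."""
    images = {}
    for name in sorted(os.listdir(data_dir)):
        class_dir = os.path.join(data_dir, name)
        if not os.path.isdir(class_dir):
            continue  # Skip loose files such as a license text.
        images[name] = sorted(
            os.path.join(class_dir, f) for f in os.listdir(class_dir)
            if f.lower().endswith(('.jpg', '.jpeg', '.png')))
    return images


def shuffle_and_split(images_by_class, test_ratio=0.1, seed=0):
    """Returns train and test lists of (path, label_id) pairs and the label map."""
    rng = random.Random(seed)
    label_map = {i: name for i, name in enumerate(sorted(images_by_class))}
    train, test = [], []
    for label_id, name in label_map.items():
        paths = list(images_by_class[name])
        rng.shuffle(paths)
        num_test = int(len(paths) * test_ratio)
        test += [(p, label_id) for p in paths[:num_test]]
        train += [(p, label_id) for p in paths[num_test:]]
    return train, test, label_map


# Demo on a throwaway directory tree instead of the real dataset.
with tempfile.TemporaryDirectory() as demo_dir:
    for name in ('daisy', 'roses'):
        os.makedirs(os.path.join(demo_dir, name))
        for i in range(10):
            open(os.path.join(demo_dir, name, f'{i}.jpg'), 'wb').close()
    train, test, label_map = shuffle_and_split(list_images_by_class(demo_dir))

print(label_map, len(train), len(test))
```

Splitting per class, rather than over the whole list, keeps every class represented in both the training and the test set.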

Create an embedding extractor

To use this backpropagation technique with your own model, you need to compile your TensorFlow Lite model with its last layer removed. Doing so creates a model called an embedding extractor, which outputs an image embedding (also called a feature embedding tensor).

Separating the embedding extractor allows the last fully-connected layer to be implemented on-device (with SoftmaxRegression) so we can backpropagate new weights. Assuming you've already trained a classification model with one of the supported model architectures, you can follow the steps below to create an embedding extractor from that pre-trained model.

Note: This procedure only works for models built using quantization-aware training. If you want to create an embedding extractor for a model created with post-training quantization, then you need to specify the output tensors when you quantize the model (which is not documented here).

Also, this technique does not work for EfficientNet, but you can instead use the EfficientNet tools to create the embedding extractor, or download a pre-compiled version from our Models page.

  1. Identify the feature embedding tensor. A feature embedding tensor is the input tensor for the last fully-connected layer. For the classification model architectures we officially support, the following table lists their feature embedding tensor names, and the feature dimensions.

    Model name                   Feature embedding tensor name               Size
    mobilenet_v1_1.0_224_quant   MobilenetV1/Logits/AvgPool_1a/AvgPool       1024
    mobilenet_v2_1.0_224_quant   MobilenetV2/Logits/AvgPool                  1280
    inception_v1_224_quant       InceptionV1/Logits/AvgPool_0a_7x7/AvgPool   1024
    inception_v2_224_quant       InceptionV2/Logits/AvgPool_1a_7x7/AvgPool   1024
    inception_v3_224_quant       InceptionV3/Logits/AvgPool_1a_8x8/AvgPool   2048
    inception_v4_224_quant       InceptionV4/Logits/AvgPool_1a/AvgPool       1536

    (You can also find the feature embedding tensor name when you visualize your model or list all the layers of your model using tools such as tflite_convert.)

  2. Cut off the last fully-connected layer from the pre-trained classification model. Because you'll be changing the weights in the last fully-connected layer, your embedding extractor model is just a new version of the existing model but with this last layer removed. So you'll remove this layer using the tflite_convert tool, which converts the TensorFlow frozen graph into the TensorFlow Lite format. You just need to specify the output array that is the input for the last fully-connected layer (the feature embedding tensor).

    For example, the following command creates the embedding extractor from a MobileNet v1 model and saves it as a TensorFlow Lite model.

    # Create embedding extractor from MobileNet v1 classification model
    tflite_convert \
    --output_file=mobilenet_v1_embedding_extractor.tflite \
    --graph_def_file=mobilenet_v1_1.0_224_quant_frozen.pb \
    --inference_type=QUANTIZED_UINT8 \
    --mean_values=128 \
    --std_dev_values=128 \
    --input_arrays=input \
    --output_arrays=MobilenetV1/Logits/AvgPool_1a/AvgPool

  3. Compile the embedding extractor. The embedding extractor you now have runs only on the CPU, so you need to compile it for the Edge TPU using the Edge TPU Compiler. (This is no different than compiling a full classification model.)

Now just follow the procedures described in the API summary to perform training, or pass your model to the demo script.
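As a recap of steps 1 and 2, the sketch below looks up the feature embedding tensor for a supported model and assembles the matching tflite_convert command. The tensor names and sizes are copied from the table above; the helper name and the file names in the usage lines are hypothetical, and we assume the input array name and the mean and standard deviation values from the MobileNet v1 example also apply to the other models, which you should confirm for your own graph.

```python
import shlex

# Feature embedding tensor names and sizes, copied from the table above.
EMBEDDING_TENSORS = {
    'mobilenet_v1_1.0_224_quant': ('MobilenetV1/Logits/AvgPool_1a/AvgPool', 1024),
    'mobilenet_v2_1.0_224_quant': ('MobilenetV2/Logits/AvgPool', 1280),
    'inception_v1_224_quant': ('InceptionV1/Logits/AvgPool_0a_7x7/AvgPool', 1024),
    'inception_v2_224_quant': ('InceptionV2/Logits/AvgPool_1a_7x7/AvgPool', 1024),
    'inception_v3_224_quant': ('InceptionV3/Logits/AvgPool_1a_8x8/AvgPool', 2048),
    'inception_v4_224_quant': ('InceptionV4/Logits/AvgPool_1a/AvgPool', 1536),
}


def tflite_convert_command(model_name, frozen_graph, output_file):
    """Builds the argument list that cuts the model at its embedding tensor.

    Hypothetical helper; it only assembles the command and does not run it.
    """
    tensor_name, _ = EMBEDDING_TENSORS[model_name]
    return [
        'tflite_convert',
        f'--output_file={output_file}',
        f'--graph_def_file={frozen_graph}',
        '--inference_type=QUANTIZED_UINT8',
        '--mean_values=128',
        '--std_dev_values=128',
        # Assumption: the graph's input tensor is named "input".
        '--input_arrays=input',
        # The output array is the input of the last fully-connected layer.
        f'--output_arrays={tensor_name}',
    ]


command = tflite_convert_command(
    'mobilenet_v1_1.0_224_quant',
    'mobilenet_v1_1.0_224_quant_frozen.pb',
    'mobilenet_v1_embedding_extractor.tflite')
print(shlex.join(command))
```

The size stored next to each tensor name is the feature dimension of the embedding, which is the input dimension the on-device softmax layer has to match.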