Retrain a classification model on-device with weight imprinting

Weight imprinting is a technique for retraining a neural network (classification models only) using a small set of sample data, based on the technique described in Low-Shot Learning with Imprinted Weights. It's designed to update the weights for only the last layer of the model, but in a way that can retain existing classes while adding new ones. We've implemented this technique in the ImprintingEngine API, allowing you to accelerate transfer-learning with the Edge TPU.


To use the ImprintingEngine API, you need to provide a specially-designed model that separates the embedding extractor from the last layer where classification occurs. This is necessary because once a model is compiled for the Edge TPU, the network's weights are locked and cannot be changed—by separating the last layer and compiling only the base of the graph, we can update weights in the classification layer. Additionally, the weight imprinting technique requires a few changes to the model architecture to facilitate more accurate weights for the last layer (such as an additional L2-normalization layer and an added scaling factor). (For all the details about the model architecture, read Low-Shot Learning with Imprinted Weights.)

Unlike the on-device backpropagation technique, however, the model you provide for weight imprinting must be the complete graph (not just the embedding extractor). The model is still divided into separate parts for the embedding extractor and the classification layer, and only the base portion is compiled for the Edge TPU, but the two parts are recombined so that the classifications from the original model are preserved. The original classes cannot be retrained, though: you can train and update only the new classes that you add using the ImprintingEngine API.
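
To make the idea concrete, here is a minimal pure-Python sketch of weight imprinting in the spirit of the paper: each new class weight is the L2-normalized mean of that class's L2-normalized embeddings, scores are scaled cosine similarities, and rows inherited from the original model stay frozen. The ToyImprintedHead class, its method names, and the scale value of 10.0 are illustrative assumptions. This is not the ImprintingEngine implementation, and it ignores quantization.

```python
import math


def l2_normalize(vector):
    """Scale a vector to unit length (the role of the L2-normalization layer)."""
    norm = math.sqrt(sum(x * x for x in vector))
    if norm == 0.0:
        raise ValueError("cannot normalize a zero vector")
    return [x / norm for x in vector]


class ToyImprintedHead:
    """Toy last layer with one unit-length weight row per class (hypothetical)."""

    def __init__(self, frozen_weights=(), scale=10.0):
        # Rows inherited from the pre-trained model are kept but never updated.
        self.weights = [l2_normalize(w) for w in frozen_weights]
        self.num_frozen = len(self.weights)
        self.sums = {}      # class_id -> running sum of normalized embeddings
        self.scale = scale  # assumed value for the added scaling factor

    def train(self, embedding, class_id):
        if class_id < self.num_frozen:
            raise ValueError("original classes cannot be retrained")
        if class_id > len(self.weights):
            raise ValueError("new class IDs must be contiguous")
        unit = l2_normalize(embedding)
        total = self.sums.get(class_id, [0.0] * len(unit))
        total = [t + u for t, u in zip(total, unit)]
        self.sums[class_id] = total
        # Imprint: the class weight is the normalized mean embedding.
        # (Normalizing the sum gives the same direction as the mean.)
        row = l2_normalize(total)
        if class_id == len(self.weights):
            self.weights.append(row)
        else:
            self.weights[class_id] = row

    def classify(self, embedding):
        unit = l2_normalize(embedding)
        logits = [self.scale * sum(w * u for w, u in zip(row, unit))
                  for row in self.weights]
        peak = max(logits)  # subtract the max for a numerically stable softmax
        exps = [math.exp(v - peak) for v in logits]
        denom = sum(exps)
        return [e / denom for e in exps]


head = ToyImprintedHead(frozen_weights=[[1.0, 0.0, 0.0]])
head.train([0.0, 2.0, 0.1], class_id=1)   # first sample of a new class
head.train([0.0, 1.0, -0.1], class_id=1)  # continue training the same class
scores = head.classify([0.1, 3.0, 0.0])
predicted = scores.index(max(scores))
```

In this toy run the query embedding points almost along the new class's prototype, so class 1 wins, while any attempt to train class 0 (the frozen original class) raises an error.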

Of course, this strategy has both benefits and drawbacks:

Benefits:

  • Transfer-learning happens on-device, at near-realtime speed.
  • Very few sample images are required (fewer than 10 training samples can achieve high accuracy).
  • You don't need to recompile the model.

Drawbacks:

  • It has difficulty learning from datasets with large intra-class variation (when the data for a given class contains large variation across samples, such as major differences in the subject angle or size). If your use-case expects data with high intra-class variance, consider instead using on-device transfer learning with backpropagation (it requires a larger training dataset).
  • The last fully-connected layer executes on the CPU, not the Edge TPU. However, this layer represents a very small portion of the overall network, so impact on the inference speed is minimal.
  • It has specific model architecture requirements. We've shared a version of MobileNet v1 that is compatible (see below), but if you prefer a different model, then you must make the necessary changes to your model.
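
The intra-class variation drawback can be illustrated with a toy calculation. The sketch below is not from the paper or the API: it assumes just two unit-length 2-D embeddings per class, placed symmetrically around a common direction, and the function name is hypothetical. It shows that the wider the samples are spread, the less each one resembles the single averaged prototype that imprinting stores for the class.

```python
import math


def prototype_similarity(half_angle_degrees):
    """Cosine similarity between each sample and the class prototype.

    Two unit-length 2-D embeddings sit at +angle and -angle around the
    x-axis; the prototype is their normalized mean.
    """
    theta = math.radians(half_angle_degrees)
    samples = [(math.cos(theta), math.sin(theta)),
               (math.cos(theta), -math.sin(theta))]
    mean = [sum(axis) / len(samples) for axis in zip(*samples)]
    norm = math.hypot(*mean)
    prototype = [m / norm for m in mean]
    return sum(p * s for p, s in zip(prototype, samples[0]))


tight_class = prototype_similarity(10)   # samples that look alike
spread_class = prototype_similarity(60)  # samples that differ a lot
```

Here the tightly clustered class keeps a similarity above 0.98 to its prototype, while the widely spread class drops to about 0.5, which leaves much less margin against other classes.
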
Note: Another way you can perform on-device transfer-learning is with the SoftmaxRegression API, which instead uses backpropagation to update the weights of the last layer. For a comparison of these two techniques, read Transfer-learning on-device.

API summary

The ImprintingEngine class encapsulates the entire model that you want to train. Once you create an instance with a compatible model, you can pass it training data to update the weights in the last layer, and then immediately use the model to run inferences.

The basic procedure to perform weight imprinting with the ImprintingEngine API is as follows:

  1. Create an instance of ImprintingEngine by specifying a compatible TensorFlow Lite model. Most applications should use our pre-trained model (mobilenet_v1_1.0_224_l2norm_quant_edgetpu.tflite), but you can also retrain the MobileNet model or build your own model.

    The initialization function allows you to specify whether to keep the classifications from the pre-trained model or abandon them and use only the classes you're about to add.

  2. Create an instance of Interpreter for the Edge TPU, using the embedding extractor portion of the ImprintingEngine model, which serialize_extractor_model() provides. For example:

    engine = ImprintingEngine(model_path)
    extractor_interpreter = make_edgetpu_interpreter(engine.serialize_extractor_model())
  3. For each training image, run an inference with the Interpreter and collect the output (which is the image embedding).

  4. Then train a new class or continue training an existing class by calling train(), which takes the image embedding and a label ID.

  5. Save the retrained model using serialize_model(). For example:

    with open(output_path, 'wb') as f:
        f.write(engine.serialize_model())
  6. Then use the new model to run inferences with PyCoral and TensorFlow Lite.
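
Steps 3 to 5 can be sketched as a small helper. This is only an outline under stated assumptions: imprint_samples and embed are hypothetical names, embed stands for whatever code runs the extractor Interpreter on one image and returns its embedding, and engine is any object that provides the train() and serialize_model() methods described above. The _StandInEngine class exists only so the sketch runs without an Edge TPU; it is not part of PyCoral.

```python
import os
import tempfile


def imprint_samples(engine, embed, samples_by_class, output_path):
    """Train each class from its samples, then save the retrained model.

    engine: object with train(embedding, class_id) and serialize_model().
    embed: hypothetical callable mapping one image to its embedding.
    samples_by_class: dict of {class_id: [image, ...]}.
    """
    for class_id in sorted(samples_by_class):
        for image in samples_by_class[class_id]:
            # Step 3: run the extractor to get the image embedding.
            embedding = embed(image)
            # Step 4: imprint (or keep training) the class with this embedding.
            engine.train(embedding, class_id)
    # Step 5: write the serialized model to disk.
    with open(output_path, 'wb') as f:
        f.write(engine.serialize_model())


class _StandInEngine:
    """Records calls so the sketch can run without an Edge TPU (not PyCoral)."""

    def __init__(self):
        self.calls = []

    def train(self, embedding, class_id):
        self.calls.append((class_id, embedding))

    def serialize_model(self):
        return b'stand-in model bytes'


stand_in = _StandInEngine()
output_path = os.path.join(tempfile.mkdtemp(), 'retrained.tflite')
imprint_samples(stand_in, embed=lambda image: [float(image)],
                samples_by_class={0: [1, 2], 1: [3]}, output_path=output_path)
```

With the real API, engine would be the ImprintingEngine instance from step 1 and embed would wrap the Interpreter created in step 2.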

See the next section for a walkthrough with our example code.

Retrain a model on-device with our sample

To show you how this works, we've created a sample script that uses ImprintingEngine to perform on-device transfer learning with a given model.

The model you'll retrain with this sample is a modified MobileNet v1 model that's pre-trained to understand 1,000 classes from the ImageNet dataset. The ImprintingEngine API allows you to keep the original classes learned from pre-training, but in this sample, you'll abandon those and retrain it to understand just 10 classes (the model retains all the feature extractors from the base model—we only reset the final classifications).

If you're using the Dev Board, execute the following commands on the board's terminal. If you're using the USB Accelerator, be sure it's connected to the host computer where you'll run these commands.

  1. Set up the directory where you'll save all your work (first set the DEMO_DIR variable to a path of your choice):

    mkdir -p $DEMO_DIR
  2. Download our pre-trained model (a custom version of MobileNet v1):

    wget  -P $DEMO_DIR

    This model is pre-trained with 1,000 classes from ImageNet, so the base model has very good feature extractors. But if you want to pre-train this model with your own dataset, see the section below about how to retrain the base MobileNet model (this is not the usual procedure to train MobileNet).

  3. Download our sample training dataset (10 classes with about 20 photos each):

    wget  -P $DEMO_DIR
    tar zxf $DEMO_DIR/imprinting_data_script.tar.gz -C $DEMO_DIR
    bash $DEMO_DIR/imprinting_data_script/ $DEMO_DIR

    Downloading the images takes a couple of minutes (depending on your internet speed).

  4. Download and navigate to the sample code that performs retraining:

    mkdir coral && cd coral
    git clone
    cd pycoral/examples/
  5. Start transfer learning on the Edge TPU:

    python3 \
    --model_path ${DEMO_DIR}/mobilenet_v1_1.0_224_l2norm_quant_edgetpu.tflite \
    --data ${DEMO_DIR}/open_image_v4_subset \
    --output ${DEMO_DIR}/retrained_imprinting_model.tflite

    This should take 1 - 2 minutes when using our sample dataset. When it's done, the newly trained model is saved at ${DEMO_DIR}/retrained_imprinting_model.tflite.

  6. Try the retrained model by running it through the script:

    # Download a new cat photo from Open Images:
    curl -o ${DEMO_DIR}/cat.jpg

    python3 \
    --model ${DEMO_DIR}/retrained_imprinting_model.tflite \
    --label ${DEMO_DIR}/retrained_imprinting_model.txt \
    --image ${DEMO_DIR}/cat.jpg

    You should see results such as this:

    ---------------------------
    Cat
    Score : 0.9921875

That's it! You've just trained a model with weight imprinting on the Edge TPU.

To repeat this demo with your own dataset, just add a new directory inside the open_image_v4_subset directory, and add some photos of a new class (even just 5 - 10 photos should work). Then repeat steps 5 and 6 to retrain the model and perform an inference.
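
The dataset layout described above, one subdirectory per class with that class's photos inside, can be read with a few lines of Python. This is a sketch rather than the sample script's own loader; collect_class_images and the accepted file extensions are assumptions made for illustration.

```python
import os
import tempfile


def collect_class_images(data_dir, extensions=('.jpg', '.jpeg', '.png')):
    """Return {class_name: [image paths]} for each subdirectory of data_dir."""
    classes = {}
    for name in sorted(os.listdir(data_dir)):
        class_dir = os.path.join(data_dir, name)
        if not os.path.isdir(class_dir):
            continue  # ignore stray files next to the class directories
        images = sorted(
            os.path.join(class_dir, f) for f in os.listdir(class_dir)
            if f.lower().endswith(extensions))
        if images:
            classes[name] = images
    return classes


# Build a tiny throwaway dataset to show the expected layout.
demo_dir = tempfile.mkdtemp()
for label, count in (('cat', 2), ('hedgehog', 3)):
    os.makedirs(os.path.join(demo_dir, label))
    for i in range(count):
        open(os.path.join(demo_dir, label, f'{i}.jpg'), 'wb').close()

dataset = collect_class_images(demo_dir)
class_sizes = {label: len(paths) for label, paths in dataset.items()}
```

Adding a new class is then just a matter of creating one more subdirectory; the directory name serves as the label.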

Retrain the base MobileNet model

Caution: This section is intended for advanced TensorFlow users who want to customize the base model. The process requires multiple build steps and the model training may require several hours, even when using a powerful GPU. Depending on your environment, completing this process might require different or additional steps compared to those provided below, so familiarity with the TensorFlow work environment is necessary.

The MobileNet model we shared for the above demo was trained with 1,000 classes from ImageNet ILSVRC2012, which results in a model with very good feature extractors for a variety of image classification tasks. However, if you want to fine-tune the base MobileNet model with your own training dataset, you can do so as follows.

Although the training above was accelerated by the Edge TPU, the following retraining for the base MobileNet model cannot run on the Edge TPU, and some required tools are not compatible with the Coral Dev Board, so you should perform these steps on a powerful desktop computer.


Pre-train the model

  1. Sync our Git repo that contains the training scripts:

    git clone
    cd imprinting-training
    git submodule init && git submodule update
    export PYTHONPATH=$(pwd):$(pwd)/models/research/slim
  2. Build our modified MobileNet v1 model with L2-normalization:

    cd classification
    bazel build mobilenet_v1_l2norm
  3. Start the training script with the model checkpoint and dataset (set the variables for your own data paths):

    # Location of your TFRecord files
    DATASET_DIR=/home/edgetpu/classify/flowers

    # Location of your checkpoint (the common path for all .ckpt files)
    FINETUNE_CHECKPOINT_PATH=/home/edgetpu/classify/train/model.ckpt-300

    # Destination for the training logs
    CHECKPOINT_DIR=/home/edgetpu/l2norm-training

    python3 \
    --quantize=True \
    --dataset_dir=${DATASET_DIR} \
    --fine_tune_checkpoint=${FINETUNE_CHECKPOINT_PATH} \
    --checkpoint_dir=${CHECKPOINT_DIR} \
    --freeze_base_model=True \
    --number_of_steps=100000
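  4. When you're ready to evaluate the model performance, you can do so as follows (CHECKPOINT_FILE is not set in the previous step, so first point it at the checkpoint you want to evaluate):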

    python3 \
    --quantize=True \
    --checkpoint_dir=$CHECKPOINT_FILE \
    --dataset_dir=$DATASET_DIR

Now you have a pre-trained MobileNet model (with L2-normalization).

The next section shows how to convert the model for the Edge TPU.

Export the graph for the Edge TPU

  1. Save a GraphDef of the model:

    # Still inside the imprinting-training/classification/ directory
    python3 \
    --quantize=True \
    --output_file=$CHECKPOINT_DIR/mobilenet_v1_l2norm_inf_graph.pb
  2. Freeze the graph with your new checkpoint:

    # Use the tensorflow repo inside imprinting-training/
    cd ../tensorflow/
    bazel build tensorflow/python/tools:freeze_graph

    # Check the output of the build command for the freeze_graph location
    freeze_graph \
    --input_graph=$CHECKPOINT_DIR/mobilenet_v1_l2norm_inf_graph.pb \
    --input_checkpoint=$CHECKPOINT_FILE \
    --input_binary=true \
    --output_graph=$CHECKPOINT_DIR/frozen_mobilenet_v1_l2norm.pb \
    --output_node_names=MobilenetV1/Predictions/Reshape_1
  3. Strip out the L2-norm operator that we added in the base graph, because the Edge TPU does not support this operation (removing it at this point does not change the model's behavior, because the new weights are already frozen the way we want them):

    bazel build tensorflow/tools/graph_transforms:transform_graph
    transform_graph \
    --in_graph=$CHECKPOINT_DIR/frozen_mobilenet_v1_l2norm.pb \
    --out_graph=$CHECKPOINT_DIR/frozen_mobilenet_v1_l2norm_optimized.pb \
    --inputs=input \
    --outputs=MobilenetV1/Predictions/Reshape_1 \
    --transforms='strip_unused_nodes fold_constants'
  4. Now run the following commands to separate both the model base (the embedding extractor) and the model head (the classification layer) as individual graphs.

    1. First convert the entire frozen graph to a TensorFlow Lite file (or else the model head will have the wrong input parameters):

      tflite_convert \
      --graph_def_file=$CHECKPOINT_DIR/frozen_mobilenet_v1_l2norm_optimized.pb \
      --output_file=$CHECKPOINT_DIR/mobilenet_v1_l2norm_quant.tflite \
      --inference_type=QUANTIZED_UINT8 \
      --mean_values=128 \
      --std_dev_values=128 \
      --input_arrays=input \
      --output_arrays=MobilenetV1/Predictions/Reshape_1
    2. Create the base graph as its own file (using toco because tflite_convert does not support .tflite files as input):

      # You should build the following version of toco because the
      # packaged version does not support the 'input_file' argument
      bazel build tensorflow/lite/toco:toco

      toco \
      --input_file=$CHECKPOINT_DIR/mobilenet_v1_l2norm_quant.tflite \
      --output_file=$CHECKPOINT_DIR/mobilenet_v1_embedding_extractor.tflite \
      --input_format=TFLITE \
      --output_format=TFLITE \
      --inference_type=QUANTIZED_UINT8 \
      --input_arrays=input \
      --output_arrays=MobilenetV1/Logits/AvgPool_1a/AvgPool
    3. Create the head graph:

      toco \
      --input_file=$CHECKPOINT_DIR/mobilenet_v1_l2norm_quant.tflite \
      --output_file=$CHECKPOINT_DIR/mobilenet_v1_last_layers.tflite \
      --input_format=TFLITE \
      --output_format=TFLITE \
      --inference_type=QUANTIZED_UINT8 \
      --input_arrays=MobilenetV1/Logits/AvgPool_1a/AvgPool \
      --output_arrays=MobilenetV1/Predictions/Reshape_1
  5. Compile the base graph with the Edge TPU Compiler:

    edgetpu_compiler -o $CHECKPOINT_DIR \
    $CHECKPOINT_DIR/mobilenet_v1_embedding_extractor.tflite
  6. Then re-join the compiled base graph to the head graph using our join_tflite_models tool.

    First clone our code repo and build the tool from within Docker:

    git clone
    # This step might take a few minutes:
    make docker-shell

    # Set the CPU variable and build based on your system architecture...
    # Either:
    CPU=k8 make tools
    # Or:
    CPU=armv7a make tools
    # Or:
    CPU=aarch64 make tools

    # Then exit Docker
    exit

    Now run the join tool (using the path in out/ based on what you built above):

    # Your prompt should be in the edgetpu/ root:
    ./out/k8/tools/join_tflite_models \
    --input_graph_base=$CHECKPOINT_DIR/mobilenet_v1_embedding_extractor_edgetpu.tflite \
    --input_graph_head=$CHECKPOINT_DIR/mobilenet_v1_last_layers.tflite \

Now you're done.

You can move the mobilenet_v1_l2norm_quant_edgetpu.tflite file to your Edge TPU device and use it with the sample script above or your own code using the ImprintingEngine API.
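
The --mean_values=128 and --std_dev_values=128 flags passed to tflite_convert in the export steps above describe how the quantized uint8 input maps to real values: real = (quantized - mean) / std_dev. The sketch below simply works through that arithmetic; the function name is hypothetical.

```python
def dequantize_input(quantized, mean_value=128.0, std_dev_value=128.0):
    """Map a uint8 input value to the real value it represents."""
    if not 0 <= quantized <= 255:
        raise ValueError("expected a uint8 value in [0, 255]")
    return (quantized - mean_value) / std_dev_value


darkest = dequantize_input(0)      # -1.0
middle = dequantize_input(128)     # 0.0
brightest = dequantize_input(255)  # 0.9921875
```

So with these settings, pixel values in [0, 255] cover the real range from -1.0 up to just under 1.0.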

Build a different model for ImprintingEngine

Everything above uses a version of the MobileNet v1 model we created specifically for weight imprinting, because ImprintingEngine is not compatible with ordinary classification models or embedding extractors.

Creating a different classification model that's compatible with ImprintingEngine is possible, but it's a significant undertaking that demands expert TensorFlow knowledge.

A complete description of the architecture we implemented for our model is beyond the scope of this document, but you can inspect our implementation in

We also suggest you carefully read the research paper that this design is based upon: Low-Shot Learning with Imprinted Weights.