A softmax regression model for on-device backpropagation of the last layer.

class pycoral.learn.backprop.softmax_regression.SoftmaxRegression(feature_dim=None, num_classes=None, weight_scale=0.01, reg=0.0)

An implementation of the softmax regression function (multinomial logistic regression) that operates as the last layer of your classification model and allows for on-device training with backpropagation (for this layer only).

The input for this layer must be an image embedding, which should be the output of your embedding extractor (the backbone of your model). The input is fed to a fully-connected layer, which applies the weights and bias, and the result is passed to the softmax function, which produces the final probability distribution over your model’s classes:

training/inference input (image embedding) –> fully-connected layer –> softmax function
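
A minimal NumPy sketch of this forward pass, assuming a weight matrix of shape D x C and a bias vector of length C. The names `mat_w`, `vec_b`, `softmax`, and `forward` are illustrative only, not PyCoral internals:

```python
import numpy as np


def softmax(scores):
    """Row-wise softmax; subtracting each row's max keeps exp() from overflowing."""
    shifted = scores - scores.max(axis=1, keepdims=True)
    exp_scores = np.exp(shifted)
    return exp_scores / exp_scores.sum(axis=1, keepdims=True)


def forward(mat_x, mat_w, vec_b):
    """Fully-connected layer (x @ W + b) followed by softmax.

    mat_x: N x D embeddings, mat_w: D x C weights, vec_b: length-C bias.
    Returns an N x C matrix of class probabilities.
    """
    scores = mat_x @ mat_w + vec_b
    return softmax(scores)


rng = np.random.default_rng(0)
mat_x = rng.standard_normal((4, 8))          # 4 stand-in embeddings, D = 8
mat_w = 0.01 * rng.standard_normal((8, 3))   # C = 3 classes
vec_b = np.zeros(3)
probs = forward(mat_x, mat_w, vec_b)
print(probs.shape)  # (4, 3)
print(np.allclose(probs.sum(axis=1), 1.0))  # True: each row is a distribution
```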

When you train with train_with_sgd(), the process uses a cross-entropy loss function to measure the error and then updates the weights of the fully-connected layer (backpropagation).
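
The loss and gradient computation for this layer can be sketched in NumPy as follows. This assumes integer class labels and an L2 penalty of the form 0.5 * reg * sum(W**2); `loss_and_grads` is a hypothetical name, and the library's exact formulation may differ:

```python
import numpy as np


def loss_and_grads(mat_x, labels, mat_w, vec_b, reg=0.0):
    """Softmax cross-entropy loss with an (assumed) L2 penalty, plus gradients.

    mat_x: N x D embeddings; labels: length-N integer class indices.
    Returns (loss, grad_w, grad_b).
    """
    n = mat_x.shape[0]
    scores = mat_x @ mat_w + vec_b
    scores -= scores.max(axis=1, keepdims=True)  # numerical stability
    probs = np.exp(scores)
    probs /= probs.sum(axis=1, keepdims=True)

    # Mean negative log-probability assigned to the correct class.
    loss = -np.log(probs[np.arange(n), labels]).mean()
    loss += 0.5 * reg * np.sum(mat_w * mat_w)

    # For softmax + cross-entropy, d(loss)/d(scores) = (probs - one_hot) / N.
    dscores = probs.copy()
    dscores[np.arange(n), labels] -= 1.0
    dscores /= n

    grad_w = mat_x.T @ dscores + reg * mat_w  # backpropagate through x @ W
    grad_b = dscores.sum(axis=0)
    return loss, grad_w, grad_b


rng = np.random.default_rng(0)
x = rng.standard_normal((5, 4))
y = np.array([0, 2, 1, 2, 0])
loss, grad_w, grad_b = loss_and_grads(x, y, np.zeros((4, 3)), np.zeros(3))
print(np.isclose(loss, np.log(3)))  # True: zero weights give p = 1/3 per class
```

With all-zero weights every class receives probability 1/C, so the loss equals log(C). This is a quick sanity check for any softmax implementation before training starts.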

When you’re satisfied with the inference accuracy, call serialize_model() to create a new model, returned as a bytes object, with this retrained layer appended to your embedding extractor. You can then run inference with this new model as usual (using the TensorFlow Lite interpreter API).


This last layer (FC + softmax) in the retrained model always runs on the host CPU instead of the Edge TPU. As long as the rest of your embedding extractor model is compiled for the Edge TPU, running this last layer on the CPU should not significantly affect the inference speed.

For more detail, see the Stanford CS231n explanation of the softmax classifier.

  • feature_dim (int) – The dimension of the input feature (length of the feature vector).

  • num_classes (int) – The number of output classes.

  • weight_scale (float) – A scaling factor for the initial weights. The initial weights are drawn from a standard normal distribution and then multiplied by this number to keep their scale small.

  • reg (float) – The regularization strength.
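
To illustrate how feature_dim, num_classes, and weight_scale determine the layer's initial parameters, here is a hypothetical initializer (`init_params` is not part of PyCoral). It assumes the bias starts at zero, which the parameter list above does not state:

```python
import numpy as np


def init_params(feature_dim, num_classes, weight_scale=0.01, seed=None):
    """Hypothetical initializer illustrating the documented weight_scale.

    Weights: standard-normal samples multiplied by weight_scale.
    Bias: zeros (an assumption; the docs describe only the weights).
    """
    rng = np.random.default_rng(seed)
    mat_w = weight_scale * rng.standard_normal((feature_dim, num_classes))
    vec_b = np.zeros(num_classes)
    return mat_w, vec_b


mat_w, vec_b = init_params(feature_dim=1024, num_classes=5, seed=0)
print(mat_w.shape, vec_b.shape)  # (1024, 5) (5,)
print(mat_w.std())  # close to weight_scale (0.01)
```

Small initial weights keep the initial scores near zero, so the softmax starts out close to uniform and early gradients are well behaved.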

get_accuracy(mat_x, labels)

Calculates the model’s accuracy (percentage correct).

The accuracy is computed by running inference on the given data and comparing the predictions with the given labels.

  • mat_x (numpy.array) – The input data (image embeddings) to test, as a matrix of shape NxD, where N is the number of inputs to test and D is the dimension of the input feature (length of the feature vector).

  • labels (numpy.array) – An array of the correct label indices that correspond to the test data passed in mat_x (the class label index, i.e. the position of the 1 in the one-hot vector).
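
A sketch of how accuracy over integer label indices can be computed from a model's output probabilities. `accuracy` is a hypothetical helper; it returns a fraction in [0, 1] (multiply by 100 for a percentage):

```python
import numpy as np


def accuracy(probs, labels):
    """Fraction of rows whose highest-probability class equals the label index."""
    predictions = np.argmax(probs, axis=1)
    return float(np.mean(predictions == labels))


probs = np.array([[0.7, 0.2, 0.1],   # predicts 0
                  [0.1, 0.3, 0.6],   # predicts 2
                  [0.2, 0.5, 0.3],   # predicts 1
                  [0.5, 0.3, 0.2]])  # predicts 0
labels = np.array([0, 2, 0, 1])
print(accuracy(probs, labels))  # 0.5

# Labels stored one-hot can be turned into indices with argmax.
one_hot = np.eye(3)[labels]
print(np.argmax(one_hot, axis=1))  # [0 2 0 1]
```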


The accuracy (the percent correct) as a float.


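serialize_model(in_model_path)

Appends the learned weights to your TensorFlow Lite model and serializes it.

Beware that the learned weights and biases are quantized from float32 to uint8, so the serialized model’s outputs can differ slightly from those of the float32 layer.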


in_model_path (str) – Path to the embedding extractor model (.tflite file).


The TF Lite model with new weights, as a bytes object.
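
Converting float32 values to uint8 is typically an affine mapping, real_value = scale * (q - zero_point), which is the scheme TensorFlow Lite uses for uint8 tensors. The sketch below assumes a simple min/max calibration over the weights; how PyCoral actually chooses `scale` and `zero_point` is not specified here, and both helper functions are hypothetical:

```python
import numpy as np


def quantize_uint8(values):
    """Affine (asymmetric) quantization of a float array to uint8.

    Follows real_value = scale * (q - zero_point). The range is widened to
    include 0.0 so that zero is exactly representable.
    """
    lo = min(float(values.min()), 0.0)
    hi = max(float(values.max()), 0.0)
    scale = (hi - lo) / 255.0 if hi > lo else 1.0
    zero_point = int(round(-lo / scale))
    q = np.clip(np.round(values / scale) + zero_point, 0, 255).astype(np.uint8)
    return q, scale, zero_point


def dequantize(q, scale, zero_point):
    """Map uint8 codes back to approximate real values."""
    return scale * (q.astype(np.float32) - zero_point)


weights = 0.05 * np.random.default_rng(0).standard_normal((8, 3)).astype(np.float32)
q, scale, zero_point = quantize_uint8(weights)
restored = dequantize(q, scale, zero_point)
print(q.dtype)  # uint8
# Rounding error stays within half a quantization step (up to float rounding).
print(np.abs(restored - weights).max() <= scale)  # True
```

The error bound shows why quantization usually costs little accuracy: each weight moves by at most half of `scale`, which is small when the weight range is narrow.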

train_with_sgd(data, num_iter, learning_rate, batch_size=100, print_every=100)

Trains your model using stochastic gradient descent (SGD).

The training data must be structured in a dictionary as specified in the data argument below. Notably, the training/validation images must be passed as image embeddings, not as the original image input. That is, run the images through your embedding extractor (the backbone of your graph) and use the resulting image embeddings here.

  • data (dict) – A dictionary that maps 'data_train' to an array of training image embeddings, 'labels_train' to an array of training labels, 'data_val' to an array of validation image embeddings, and 'labels_val' to an array of validation labels.

  • num_iter (int) – The number of iterations to train.

  • learning_rate (float) – The learning rate (step size) to use in training.

  • batch_size (int) – The number of training examples to use in each iteration.

  • print_every (int) – How often, in iterations, to print the loss and the training/validation accuracy. For example, 20 prints the stats every 20 iterations. 0 disables printing.
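
The training procedure can be sketched end to end with minibatch SGD over the data dictionary described above. `train_with_sgd_sketch` is a hypothetical stand-in, not PyCoral's implementation: it assumes minibatches sampled with replacement and an L2 penalty of 0.5 * reg * sum(W**2), and it infers the number of classes from the labels. The synthetic clustered vectors below substitute for real embedding-extractor output:

```python
import numpy as np


def softmax_loss_grads(x, y, w, b, reg):
    """Cross-entropy loss (+ assumed L2 term) and gradients for one minibatch."""
    n = x.shape[0]
    scores = x @ w + b
    scores -= scores.max(axis=1, keepdims=True)
    probs = np.exp(scores)
    probs /= probs.sum(axis=1, keepdims=True)
    loss = -np.log(probs[np.arange(n), y]).mean() + 0.5 * reg * np.sum(w * w)
    probs[np.arange(n), y] -= 1.0  # probs now holds n * d(loss)/d(scores)
    dscores = probs / n
    return loss, x.T @ dscores + reg * w, dscores.sum(axis=0)


def predict(x, w, b):
    """Index of the highest-scoring class for each row (softmax is monotonic)."""
    return np.argmax(x @ w + b, axis=1)


def train_with_sgd_sketch(data, num_iter, learning_rate, batch_size=100,
                          print_every=100, reg=0.0, weight_scale=0.01, seed=0):
    """Hypothetical minibatch SGD loop over the documented data dictionary."""
    rng = np.random.default_rng(seed)
    x_train, y_train = data['data_train'], data['labels_train']
    x_val, y_val = data['data_val'], data['labels_val']
    num_classes = int(max(y_train.max(), y_val.max())) + 1  # inferred (assumption)
    w = weight_scale * rng.standard_normal((x_train.shape[1], num_classes))
    b = np.zeros(num_classes)
    for it in range(num_iter):
        # Sample a random minibatch (with replacement) for this step.
        idx = rng.choice(x_train.shape[0], size=batch_size, replace=True)
        loss, grad_w, grad_b = softmax_loss_grads(x_train[idx], y_train[idx], w, b, reg)
        w -= learning_rate * grad_w
        b -= learning_rate * grad_b
        if print_every > 0 and it % print_every == 0:
            train_acc = np.mean(predict(x_train, w, b) == y_train)
            val_acc = np.mean(predict(x_val, w, b) == y_val)
            print(f'iter {it}: loss {loss:.3f}, '
                  f'train acc {train_acc:.2f}, val acc {val_acc:.2f}')
    return w, b


# Synthetic stand-in for embeddings: 3 well-separated clusters in 16 dimensions.
rng = np.random.default_rng(42)
centers = 3.0 * rng.standard_normal((3, 16))


def make_split(n):
    labels = rng.integers(0, 3, size=n)
    return centers[labels] + rng.standard_normal((n, 16)), labels


x_tr, y_tr = make_split(300)
x_va, y_va = make_split(100)
data = {'data_train': x_tr, 'labels_train': y_tr,
        'data_val': x_va, 'labels_val': y_va}
w, b = train_with_sgd_sketch(data, num_iter=1000, learning_rate=0.01,
                             batch_size=32, print_every=200)
```

Sampling minibatches with replacement keeps the sketch short; shuffling the training set and iterating over it in epochs is a common alternative.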

API version 1.0