pycoral.learn.backprop
pycoral.learn.backprop.softmax_regression
A softmax regression model for on-device backpropagation of the last layer.
class pycoral.learn.backprop.softmax_regression.SoftmaxRegression(feature_dim=None, num_classes=None, weight_scale=0.01, reg=0.0)
An implementation of the softmax regression function (multinomial logistic regression) that operates as the last layer of your classification model and allows on-device training with backpropagation (for this layer only).
The input to this layer must be an image embedding, that is, the output of your embedding extractor (the backbone of your model). The input is fed to a fully-connected layer, where weights and a bias are applied, and the result is passed to the softmax function to produce the final probability distribution over your model’s classes:
training/inference input (image embedding) -> fully-connected layer -> softmax function
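The pipeline above can be sketched in plain NumPy. This is only an illustration of a fully-connected layer followed by softmax, not the library's implementation; the names `softmax`, `forward`, `weights`, and `bias` are hypothetical, and the shapes (4 inputs, 8 features, 3 classes) are arbitrary.

```python
import numpy as np


def softmax(logits):
    """Row-wise softmax, shifted by each row's max for numerical stability."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)


def forward(embeddings, weights, bias):
    """Fully-connected layer followed by softmax.

    embeddings: (N, D) image embeddings
    weights:    (D, C) fully-connected weights (hypothetical name)
    bias:       (C,)   fully-connected bias (hypothetical name)
    returns:    (N, C) class probabilities, one row per input
    """
    logits = embeddings @ weights + bias
    return softmax(logits)


rng = np.random.default_rng(0)
embeddings = rng.normal(size=(4, 8))        # 4 inputs, feature dimension 8
weights = 0.01 * rng.normal(size=(8, 3))    # 3 output classes
bias = np.zeros(3)
probs = forward(embeddings, weights, bias)
```

Each row of `probs` is a probability distribution over the classes, so it sums to 1.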
When you train with train_with_sgd(), the process uses a cross-entropy loss function to measure the error and then updates the weights of the fully-connected layer (backpropagation).
When you’re satisfied with the inference accuracy, call serialize_model() to create a new model, as bytes, with this retrained layer appended to your embedding extractor. You can then run inferences with this new model as usual (using the TensorFlow Lite interpreter API).
Note
This last layer (FC + softmax) in the retrained model always runs on the host CPU instead of the Edge TPU. As long as the rest of your embedding extractor model is compiled for the Edge TPU, then running this last layer on the CPU should not significantly affect the inference speed.
For more detail, see the Stanford CS231n explanation of the softmax classifier.
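As a rough illustration of the training step described above, the following NumPy sketch computes the mean cross-entropy loss for one batch and applies a single gradient update to the fully-connected weights and bias. It assumes integer class indices as labels and omits regularization; the function name `cross_entropy_step` and the toy data are hypothetical, and the library's own update may differ in detail.

```python
import numpy as np


def cross_entropy_step(x, y, w, b, learning_rate):
    """One gradient update on a batch; returns (loss_before_update, w, b).

    x: (N, D) embeddings, y: (N,) integer class indices,
    w: (D, C) weights,    b: (C,) bias.
    """
    n = x.shape[0]
    logits = x @ w + b
    logits -= logits.max(axis=1, keepdims=True)   # numerical stability
    probs = np.exp(logits)
    probs /= probs.sum(axis=1, keepdims=True)
    loss = -np.log(probs[np.arange(n), y]).mean()

    # Gradient of the mean cross-entropy w.r.t. the logits:
    # (probs - one_hot(y)) / n
    dlogits = probs.copy()
    dlogits[np.arange(n), y] -= 1.0
    dlogits /= n

    w = w - learning_rate * (x.T @ dlogits)
    b = b - learning_rate * dlogits.sum(axis=0)
    return loss, w, b


rng = np.random.default_rng(0)
x = rng.normal(size=(60, 5))
y = (x[:, 0] > 0).astype(int)            # toy two-class labels
w = 0.01 * rng.normal(size=(5, 2))
b = np.zeros(2)

losses = []
for _ in range(50):
    loss, w, b = cross_entropy_step(x, y, w, b, learning_rate=0.5)
    losses.append(loss)
```

For simplicity the loop reuses the whole toy set as one batch on every iteration; stochastic gradient descent would instead draw a fresh batch each time.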
- Parameters
feature_dim (int) – The dimension of the input feature (length of the feature vector).
num_classes (int) – The number of output classes.
weight_scale (float) – A weight factor for computing new weights. The backpropagated weights are drawn from a standard normal distribution, then multiplied by this number to keep the scale small.
reg (float) – The regularization strength.
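The role of weight_scale and reg can be sketched as follows. The scaling of standard-normal draws follows the description above; the zero-initialized bias and the L2 form of the penalty are assumptions, since this reference does not state them. `init_params` and `l2_penalty` are hypothetical helper names.

```python
import numpy as np


def init_params(feature_dim, num_classes, weight_scale=0.01, seed=None):
    """Draw weights from a standard normal and scale them down.

    The zero bias is an assumption made for this sketch.
    """
    rng = np.random.default_rng(seed)
    weights = weight_scale * rng.standard_normal((feature_dim, num_classes))
    bias = np.zeros(num_classes)
    return weights, bias


def l2_penalty(weights, reg):
    """Assumed L2 penalty added to the loss: 0.5 * reg * sum(w ** 2)."""
    return 0.5 * reg * np.sum(weights * weights)


weights, bias = init_params(feature_dim=1280, num_classes=10, seed=0)
penalty = l2_penalty(weights, reg=0.1)
```

With the default reg=0.0 such a penalty vanishes, so the loss would be the unregularized cross-entropy.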
get_accuracy(mat_x, labels)
Calculates the model’s accuracy (percentage correct) by running inference on the given data and comparing the predictions with the given labels.
- Parameters
mat_x (numpy.array) – The input data (image embeddings) to test, as a matrix of shape NxD, where N is the number of inputs to test and D is the dimension of the input feature (length of the feature vector).
labels (numpy.array) – An array of the correct label indices that correspond to the test data passed in mat_x (class label index in one-hot vector).
- Returns
The accuracy (the percent correct) as a float.
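Conceptually, an accuracy of this kind compares the most probable class for each input against the correct label index. The sketch below is a hypothetical standalone helper, not the library method, and it returns a fraction between 0 and 1; multiply by 100 for a percentage.

```python
import numpy as np


def accuracy(probs, labels):
    """Fraction of rows whose highest-probability class matches the label."""
    predicted = np.argmax(probs, axis=1)
    return float(np.mean(predicted == labels))


# Hand-made probabilities for 4 inputs and 3 classes.
probs = np.array([
    [0.7, 0.2, 0.1],   # predicts class 0
    [0.1, 0.8, 0.1],   # predicts class 1
    [0.3, 0.3, 0.4],   # predicts class 2
    [0.6, 0.3, 0.1],   # predicts class 0, but the label is 1
])
labels = np.array([0, 1, 2, 1])
acc = accuracy(probs, labels)   # 3 of 4 correct
```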
serialize_model(in_model_path)
Appends learned weights to your TensorFlow Lite model and serializes it.
Beware that learned weights and biases are quantized from float32 to uint8.
- Parameters
in_model_path (str) – Path to the embedding extractor model (.tflite file).
- Returns
The TF Lite model with new weights, as a bytes object.
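To give a sense of what quantizing from float32 to uint8 involves, here is a generic affine quantization sketch. It is an assumption for illustration only: this reference does not say which scheme or value ranges the library uses, and `quantize_uint8` and `dequantize` are hypothetical helpers.

```python
import numpy as np


def quantize_uint8(values):
    """Affine-quantize a float array to uint8.

    Returns (q, scale, zero_point) such that
    values ~= (q - zero_point) * scale.
    """
    lo = min(float(values.min()), 0.0)   # the range must contain 0
    hi = max(float(values.max()), 0.0)
    scale = (hi - lo) / 255.0 or 1.0     # avoid a zero scale for all-zero input
    zero_point = int(round(-lo / scale))
    q = np.clip(np.round(values / scale) + zero_point, 0, 255).astype(np.uint8)
    return q, scale, zero_point


def dequantize(q, scale, zero_point):
    """Map uint8 codes back to approximate float32 values."""
    return (q.astype(np.float32) - zero_point) * scale


rng = np.random.default_rng(0)
weights = (0.01 * rng.standard_normal((16, 4))).astype(np.float32)
q, scale, zero_point = quantize_uint8(weights)
restored = dequantize(q, scale, zero_point)
max_error = float(np.max(np.abs(restored - weights)))
```

The round trip loses precision on the order of one quantization step (`scale`), which is why accuracy after serialization may differ slightly from the accuracy measured during training. Because serialize_model() returns a bytes object, the result can be saved by writing it to a file opened in binary mode.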
train_with_sgd(data, num_iter, learning_rate, batch_size=100, print_every=100)
Trains your model using stochastic gradient descent (SGD).
The training data must be structured in a dictionary as specified in the data argument below. Notably, the training/validation images must be passed as image embeddings, not as the original image input. That is, run the images through your embedding extractor (the backbone of your graph) and use the resulting image embeddings here.
- Parameters
data (dict) – A dictionary that maps 'data_train' to an array of training image embeddings, 'labels_train' to an array of training labels, 'data_val' to an array of validation image embeddings, and 'labels_val' to an array of validation labels.
num_iter (int) – The number of iterations to train.
learning_rate (float) – The learning rate (step size) to use in training.
batch_size (int) – The number of training examples to use in each iteration.
print_every (int) – How often, in iterations, to print the loss and the training/validation accuracy. For example, 20 prints the stats every 20 iterations; 0 disables printing.
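A minimal sketch of assembling the data dictionary, assuming the embeddings have already been computed by your embedding extractor. The helper `build_sgd_data`, the 80/20 split, and the random placeholder embeddings are hypothetical; only the four dictionary keys come from the parameter description above.

```python
import numpy as np


def build_sgd_data(embeddings, labels, val_fraction=0.2, seed=0):
    """Shuffle and split embeddings/labels into the four expected keys."""
    rng = np.random.default_rng(seed)
    order = rng.permutation(len(embeddings))
    num_val = int(len(embeddings) * val_fraction)
    val_idx, train_idx = order[:num_val], order[num_val:]
    return {
        'data_train': embeddings[train_idx],
        'labels_train': labels[train_idx],
        'data_val': embeddings[val_idx],
        'labels_val': labels[val_idx],
    }


# Placeholder embeddings: 50 inputs with feature dimension 16, 5 classes.
embeddings = np.random.default_rng(1).normal(size=(50, 16)).astype(np.float32)
labels = np.arange(50) % 5
data = build_sgd_data(embeddings, labels)

# With pycoral installed, the dictionary would be passed along these lines
# (argument values are placeholders, not recommendations):
#   model = SoftmaxRegression(feature_dim=16, num_classes=5)
#   model.train_with_sgd(data, num_iter=500, learning_rate=0.01)
```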
API version 1.0