edgetpu.learn.imprinting.engine
A weight imprinting engine that performs low-shot transfer-learning for image classification models.
For more information about how to use this API and how to create the type of model required, see Retrain a classification model on-device with weight imprinting.
Note
We updated ImprintingEngine in the July 2019 library update (version 2.11.1), which requires code changes if you used the previous version. The API changes are as follows:
- Most importantly, the input model has new architecture requirements. For details, read Retrain a classification model on-device with weight imprinting.
- The initialization function accepts a new keep_classes boolean to indicate whether you want to keep the pre-trained classes from the provided model.
- train() now requires a second argument for the class ID you want to train, thus allowing you to retrain classes with additional data. (It no longer returns the class ID.)
- train_all() requires a different format for the input data. It now uses a list in which each index corresponds to a class ID, and each list entry is an array of training images for that class. (It no longer returns a mapping of label IDs.)
- New methods classify_with_resized_image() and classify_with_input_tensor() allow you to immediately perform inferences, though you can still choose to save the trained model as a .tflite file with save_model().
class edgetpu.learn.imprinting.engine.ImprintingEngine(model_path, keep_classes=False)
Performs weight imprinting (transfer learning) with the given model.
Parameters:
- model_path (str) – Path to the model you want to retrain. This model must be a .tflite file output by the join_tflite_models tool. For more information about how to create a compatible model, read Retrain an image classification model on-device.
- keep_classes (bool) – If True, keep the existing classes from the pre-trained model (and use training to add additional classes). If False, drop the existing classes and train the model to include new classes only.
classify_with_input_tensor(input_tensor, threshold=0.0, top_k=3)
Performs classification with the retrained model using the given raw input tensor.
This requires you to process the input data (the image) and convert it to the appropriately formatted input tensor for your model.
Parameters:
- input_tensor (numpy.ndarray) – A 1-D array as the input tensor.
- threshold (float) – Minimum confidence threshold for returned classifications. For example, use 0.5 to receive only classifications with a confidence equal to or higher than 0.5.
- top_k (int) – The maximum number of classifications to return.
Returns: A list of classifications, each of which is a list [int, float] that represents the label id (int) and the confidence score (float).
Raises: ValueError – If argument values are invalid.
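The tensor-preparation step can be sketched as follows. The 224x224 RGB input size and the model filename are assumptions for illustration; the engine call itself requires Edge TPU hardware, so it is shown commented out:

```python
import numpy as np

def image_to_input_tensor(img_array):
    """Flatten an HxWxC uint8 image array into the 1-D tensor the engine expects."""
    return np.asarray(img_array, dtype=np.uint8).flatten()

# A placeholder image; in practice, load and resize a real photo (e.g. with PIL).
fake_img = np.zeros((224, 224, 3), dtype=np.uint8)
input_tensor = image_to_input_tensor(fake_img)
print(input_tensor.shape)  # (150528,) for a 224x224x3 image

# With an Edge TPU attached, you would then run (not executed here):
# engine = ImprintingEngine('retrained_model.tflite', keep_classes=False)
# results = engine.classify_with_input_tensor(input_tensor, threshold=0.5, top_k=3)
```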
classify_with_resized_image(img, threshold=0.1, top_k=3)
Performs classification with the retrained model using the given image.
Note: The given image must already be resized to match the model's input tensor size.
Parameters:
- img (PIL.Image) – The image to classify, already resized to match the model's input tensor size.
- threshold (float) – Minimum confidence threshold for returned classifications.
- top_k (int) – The maximum number of classifications to return.
Returns: A list of classifications, each of which is a list [int, float] that represents the label id (int) and the confidence score (float).
Raises: ValueError – If argument values are invalid.
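Both classify methods return the same [label_id, score] pairs, already filtered by threshold and limited to top_k, so handling the results is typically just a matter of mapping label IDs back to names. A minimal sketch (the label names and scores below are made up for illustration):

```python
# Hypothetical label map built while training; label IDs come from train().
labels = {0: 'background', 1: 'cat', 3: 'dog'}

# Example of the structure the classify methods return: a list of [int, float].
sample_results = [[0, 0.92], [3, 0.61]]

named = [(labels[label_id], score) for label_id, score in sample_results]
print(named)  # [('background', 0.92), ('dog', 0.61)]
```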
save_model(output_path)
Saves the newly trained model as a .tflite file.
You can then use the saved model to perform inference with ClassificationEngine. Alternatively, you can immediately perform inferences with the retrained model using the local inference methods, classify_with_resized_image() or classify_with_input_tensor().
Parameters:
- output_path (str) – The path and filename where you'd like to save the trained model (must end with .tflite).
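A short sketch of choosing a valid output path (the directory and filename are arbitrary; the save call itself requires a trained engine, so it is shown commented out):

```python
import os

# save_model() requires the filename to end with .tflite.
output_path = os.path.join('/tmp', 'my_retrained_model.tflite')
valid = output_path.endswith('.tflite')
print(output_path, valid)

# With a trained engine (not executed here):
# engine.save_model(output_path)
```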
train(input, class_id)
Trains the model with a set of images for one class.
You can use this to add new classes to the model or retrain classes that you previously added using this imprinting API.
Parameters:
- input (list of numpy.array) – The images to use for training a single class. Each numpy.array in the list represents an image as a 1-D tensor. You can convert each image to this format by passing it as a PIL.Image to numpy.asarray(). The maximum number of images allowed in the list is 200.
- class_id (int) – The label id for this class. The index must be either the number of existing classes (to add a new class to the model) or the index of an existing class that was trained using this imprinting API (you can't retrain classes from the pre-trained model).
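Building the input list can be sketched as follows. The 224x224 RGB size is an assumption, and the placeholder arrays stand in for real resized photos; the engine call needs Edge TPU hardware and is shown commented out:

```python
import numpy as np

def make_training_set(images):
    """Flatten each image into the 1-D tensor format train() expects."""
    if len(images) > 200:
        raise ValueError('train() accepts at most 200 images per class')
    return [np.asarray(img, dtype=np.uint8).flatten() for img in images]

# Placeholder images standing in for real resized photos of one class.
images = [np.zeros((224, 224, 3), dtype=np.uint8) for _ in range(5)]
training_set = make_training_set(images)
print(len(training_set), training_set[0].shape)

# With hardware attached (not executed here); to add a new class, class_id
# must equal the current number of classes in the model:
# engine.train(training_set, class_id=0)
```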
train_all(input_data)
Trains the model with multiple sets of images for multiple classes.
This essentially calls train() for each class of images you provide. You can use this to add a batch of new classes or retrain existing classes. Just beware that if you've already added new classes using the imprinting API, then the data input here must include the same classes in the same order. Alternatively, you can use train() to retrain specific classes one at a time.
Parameters:
- input_data (list of lists of numpy.array) – The images to train for multiple classes. Each index in the outer list corresponds to a class ID, and each entry is itself a list of numpy.array objects, each of which represents an image as a 1-D tensor. You can convert each image to this format by passing it as a PIL.Image to numpy.asarray().
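The nested structure that train_all() expects can be sketched as below. The outer index is the class ID and each inner list holds that class's images; the 224x224 RGB size is assumed, and the engine call requires hardware so it is commented out:

```python
import numpy as np

def fake_image():
    """A placeholder 1-D image tensor (assumed 224x224 RGB, flattened)."""
    return np.zeros((224, 224, 3), dtype=np.uint8).flatten()

input_data = [
    [fake_image() for _ in range(3)],  # class 0: three training images
    [fake_image() for _ in range(4)],  # class 1: four training images
]
print(len(input_data), [len(c) for c in input_data])

# engine.train_all(input_data)  # requires Edge TPU hardware; not executed here
```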