pycoral.learn.imprinting
pycoral.learn.imprinting.engine
A weight imprinting engine that performs low-shot transfer learning for image classification models.
For more information about how to use this API and how to create the type of model required, see Retrain a classification model on-device with weight imprinting.
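This page does not describe the engine's internals, so the following is only a sketch of the general weight-imprinting idea, not pycoral's implementation. Each example embedding is L2-normalized, the examples for a class are averaged, and the mean is renormalized to become that class's weight row. A query is then assigned to the class whose row has the largest dot product with the normalized query. The helper names (l2_normalize, imprint_weights, classify) and the toy 3-dimensional embeddings are hypothetical.

```python
import numpy as np


def l2_normalize(v, axis=-1):
    """Scale vectors to unit length along `axis` (epsilon avoids divide-by-zero)."""
    norm = np.linalg.norm(v, axis=axis, keepdims=True)
    return v / np.maximum(norm, 1e-12)


def imprint_weights(embeddings_by_class):
    """Build one unit-length weight row per class from a few example embeddings.

    embeddings_by_class: list of array-likes, each of shape (num_examples, dim).
    Returns an array of shape (num_classes, dim).
    """
    rows = []
    for examples in embeddings_by_class:
        # Normalize each example, average them, then renormalize the mean.
        mean = l2_normalize(np.asarray(examples, dtype=np.float64)).mean(axis=0)
        rows.append(l2_normalize(mean))
    return np.stack(rows)


def classify(weights, embedding):
    """Return the index of the weight row that best matches `embedding`."""
    scores = weights @ l2_normalize(np.asarray(embedding, dtype=np.float64))
    return int(np.argmax(scores))


# Two hypothetical classes with two example embeddings each.
cat = [[0.9, 0.1, 0.0], [1.0, 0.2, 0.1]]
dog = [[0.0, 1.0, 0.1], [0.1, 0.8, 0.0]]
weights = imprint_weights([cat, dog])
print(classify(weights, [0.95, 0.05, 0.0]))  # → 0
```

Because both the weight rows and the query are unit length, each score is the cosine similarity between the query and a class's imprinted direction.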
class pycoral.learn.imprinting.engine.ImprintingEngine(model_path, keep_classes=False)
Performs weight imprinting (transfer learning) with the given model.
- Parameters
  - model_path (str) – Path to the .tflite model you want to retrain. This must be a model that’s specially designed for this API. You can use our weight imprinting model that has a pre-trained base model, or you can train the base model yourself by following our guide to Retrain the base MobileNet model.
  - keep_classes (bool) – If True, keep the existing classes from the pre-trained model (and use training to add additional classes). If False, drop the existing classes and train the model to include new classes only.
property embedding_dim
Returns the number of embedding dimensions.
property num_classes
Returns the number of currently trained classes.
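To make the interaction between keep_classes, embedding_dim, and num_classes concrete, here is a hypothetical pure-NumPy stand-in. ToyImprintingEngine is not part of pycoral: instead of a .tflite path it takes the base model's classifier as a (num_classes, embedding_dim) weight matrix, which is an assumption made only for illustration. It mirrors the documented behavior: keep_classes=True keeps the pre-trained classes, and keep_classes=False drops them.

```python
import numpy as np


class ToyImprintingEngine:
    """Hypothetical stand-in mimicking the documented keep_classes behavior.

    `pretrained_weights` replaces the real model_path argument: it is an
    assumed (num_pretrained_classes, embedding_dim) weight matrix.
    """

    def __init__(self, pretrained_weights, keep_classes=False):
        weights = np.asarray(pretrained_weights, dtype=np.float32)
        self._embedding_dim = weights.shape[1]
        if keep_classes:
            # Keep every pre-trained class row; new classes are appended later.
            self._weights = weights.copy()
        else:
            # Drop the pre-trained classes; only newly imprinted classes remain.
            self._weights = np.empty((0, self._embedding_dim), dtype=np.float32)

    @property
    def embedding_dim(self):
        """Length of each embedding vector (columns of the weight matrix)."""
        return self._embedding_dim

    @property
    def num_classes(self):
        """Number of classes currently in the classifier (rows)."""
        return self._weights.shape[0]


# Arbitrary toy sizes: 5 pre-trained classes, 8-dimensional embeddings.
base = np.ones((5, 8))
print(ToyImprintingEngine(base, keep_classes=True).num_classes)   # → 5
print(ToyImprintingEngine(base, keep_classes=False).num_classes)  # → 0
```

Read together with the class_id rule for train, this means the first new class gets class_id 0 when keep_classes=False, and class_id equal to the number of pre-trained classes when keep_classes=True.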
train(embedding, class_id)
Trains the model with the given embedding for the specified class.
You can use this to add new classes to the model or retrain classes that you previously added using this imprinting API.
- Parameters
  - embedding (numpy.array) – The embedding vector used to train the specified class.
  - class_id (int) – The label ID for this class. The index must be either the number of existing classes (to add a new class to the model) or the index of an existing class that was trained using this imprinting API (you can’t retrain classes from the pre-trained model).
API version 2.0