A weight imprinting engine that performs low-shot transfer learning for image classification models.
For more information about how to use this API and how to create the type of model required, see Retrain a classification model on-device with weight imprinting.
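As a rough illustration of the general weight-imprinting idea, the plain-Python sketch below builds one weight vector per class. It L2-normalizes a few example embeddings, averages them, and normalizes the average. A new embedding is then assigned to the class with the highest cosine similarity. This is a simplified floating-point sketch under those assumptions, not this engine's implementation. The engine works on a TensorFlow Lite model and may scale, quantize, or store the weights differently. The names `l2_normalize`, `imprint_weights`, and `classify` are hypothetical.

```python
import math


def l2_normalize(vector):
    """Scale a vector to unit L2 norm (returned unchanged if it is all zeros)."""
    norm = math.sqrt(sum(x * x for x in vector))
    return [x / norm for x in vector] if norm else list(vector)


def imprint_weights(embeddings_by_class):
    """Build one weight vector per class from a few example embeddings.

    Each embedding is normalized, the normalized embeddings of a class are
    averaged, and the average is normalized again to form the class weights.
    """
    weights = {}
    for class_id, embeddings in embeddings_by_class.items():
        normalized = [l2_normalize(e) for e in embeddings]
        dim = len(normalized[0])
        mean = [sum(e[i] for e in normalized) / len(normalized) for i in range(dim)]
        weights[class_id] = l2_normalize(mean)
    return weights


def classify(weights, embedding):
    """Return the class whose weight vector has the highest cosine similarity."""
    query = l2_normalize(embedding)
    scores = {c: sum(w * q for w, q in zip(vec, query)) for c, vec in weights.items()}
    return max(scores, key=scores.get)


# Two classes, each imprinted from just two example embeddings.
weights = imprint_weights({0: [[1.0, 0.1], [0.9, 0.0]], 1: [[0.0, 1.0], [0.2, 0.8]]})
print(classify(weights, [0.8, 0.3]))  # prints 0
```

Because each class needs only an average of a handful of embeddings, adding a class does not require backpropagation, which is what makes low-shot, on-device retraining practical.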
Performs weight imprinting (transfer learning) with the given model.
model_path (str) – Path to the model you want to retrain. This model must be a .tflite file output by the join_tflite_models tool. For more information about how to create a compatible model, read Retrain an image classification model on-device.
keep_classes (bool) – If True, keep the existing classes from the pre-trained model (and use training to add additional classes). If False, drop the existing classes and train the model to include new classes only.
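This API identifies classes only by integer id, so an application usually keeps its own list of label names alongside the engine. The following is a minimal sketch of that bookkeeping. It assumes that keep_classes=True preserves the pre-trained classes at ids 0 to N-1 and that keep_classes=False starts from a model with no classes. `starting_labels` and `add_label` are hypothetical helpers, not part of the API.

```python
def starting_labels(pretrained_labels, keep_classes):
    """Hypothetical helper: label names the retrained model starts with.

    Assumption: with keep_classes=True the pre-trained classes keep ids
    0..N-1; with keep_classes=False the model starts with no classes.
    Returns a new list so the caller's pre-trained list is not modified.
    """
    return list(pretrained_labels) if keep_classes else []


def add_label(labels, name):
    """Hypothetical helper: append a class name and return its id (its index)."""
    labels.append(name)
    return len(labels) - 1


labels = starting_labels(["cat", "dog"], keep_classes=True)
new_id = add_label(labels, "hamster")  # new_id == 2, right after the kept classes
```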
Returns the number of embedding dimensions.
Returns the number of currently trained classes.
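It is easy to train with an embedding of the wrong length by accident, for example by feeding output from a different extractor model. The following hypothetical guard compares the vector length against the embedding dimension reported by the engine. PyCoral exposes that value as the `embedding_dim` property; confirm the name against your installed version.

```python
def check_embedding(embedding_dim, embedding):
    """Hypothetical guard: fail early if an embedding has the wrong length.

    `embedding_dim` would normally come from the engine's embedding-dimension
    property; `embedding` can be any 1-D sequence, such as a numpy array.
    Returns the embedding unchanged when the length matches.
    """
    if len(embedding) != embedding_dim:
        raise ValueError(
            f"expected {embedding_dim} values, got {len(embedding)}; "
            "was this embedding produced by the engine's extractor model?"
        )
    return embedding
```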
Returns the embedding extractor model as a bytes object.
Returns the newly trained model as a bytes object.
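The following is a minimal sketch of persisting the results. It assumes an engine object whose `serialize_model()` and `serialize_extractor_model()` methods return bytes, as described here. These are the PyCoral method names. `save_models` and the output paths are hypothetical.

```python
from pathlib import Path


def save_models(engine, model_path, extractor_path=None):
    """Write the retrained model, and optionally the extractor, to disk.

    `engine` is assumed to expose serialize_model() and
    serialize_extractor_model(), both returning bytes.
    """
    Path(model_path).write_bytes(engine.serialize_model())
    if extractor_path is not None:
        Path(extractor_path).write_bytes(engine.serialize_extractor_model())
```

Writing the extractor to disk is optional. In PyCoral, the extractor bytes can also be passed directly to `pycoral.utils.edgetpu.make_interpreter`, which accepts either a model path or model content, to run the extractor that produces training embeddings.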
Trains the model with the given embedding for the specified class.
You can use this to add new classes to the model or retrain classes that you previously added using this imprinting API.
(numpy.array) – The embedding vector used to train the specified single class.
class_id (int) – The label id for this class. The index must be either the number of existing classes (to add a new class to the model) or the index of an existing class that was trained using this imprinting API (you can’t retrain classes from the pre-trained model).
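The class_id rule reduces to a single range check. Let K be the number of pre-trained classes kept (K = 0 when keep_classes is False) and C the number of classes currently in the model. A valid id satisfies K <= class_id <= C, and class_id == C adds a new class. This assumes imprinted classes occupy the ids immediately after the kept pre-trained ones. The sketch below encodes that check and a loop that adds several new classes in order. It assumes an engine exposing `num_classes` and `train(embedding, class_id=...)`, as PyCoral's ImprintingEngine does. Both helper names are hypothetical.

```python
def is_valid_class_id(class_id, num_classes, num_kept_classes):
    """Hypothetical range check for the class_id rule of train().

    num_classes: classes currently in the model.
    num_kept_classes: pre-trained classes kept with keep_classes=True
    (0 when keep_classes=False); these ids cannot be retrained.
    """
    return num_kept_classes <= class_id <= num_classes


def imprint_new_classes(engine, embeddings_per_class):
    """Add one new class per entry in embeddings_per_class.

    Each entry is a list of embeddings for one class. New ids start at the
    engine's current class count, so they never collide with existing
    classes. Returns the class ids that were assigned, in order.
    """
    first_id = engine.num_classes
    assigned = []
    for offset, embeddings in enumerate(embeddings_per_class):
        class_id = first_id + offset
        for embedding in embeddings:
            engine.train(embedding, class_id=class_id)
        assigned.append(class_id)
    return assigned


# Example wiring with PyCoral (not executed here):
#   engine = ImprintingEngine(model_path, keep_classes=True)
#   ids = imprint_new_classes(engine, [embeddings_for_a, embeddings_for_b])
```

Computing every id from the starting class count, rather than re-reading the count inside the loop, keeps all embeddings of one class on the same id even after the first call to train() has added that class.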
API version 1.0