pycoral.learn.imprinting

pycoral.learn.imprinting.engine

A weight imprinting engine that performs low-shot transfer learning for image classification models.

For more information about how to use this API and how to create the type of model required, see Retrain a classification model on-device with weight imprinting.
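To make the idea concrete, here is a minimal pure-Python sketch of the general weight-imprinting technique. It assumes a class's weights are "imprinted" by L2-normalizing a few embeddings of that class, averaging them, and re-normalizing the result, and that a new embedding is classified by its dot product (cosine similarity) with each class's weights. The helper names l2_normalize, imprint_weights and predict_class are hypothetical, are not part of pycoral, and this is not necessarily how ImprintingEngine is implemented internally.

```python
import math


def l2_normalize(vector):
    """Scale a vector to unit length (an all-zero vector is returned unchanged)."""
    norm = math.sqrt(sum(x * x for x in vector))
    return [x / norm for x in vector] if norm > 0 else list(vector)


def imprint_weights(embeddings):
    """Hypothetical helper: imprint one class from a few embeddings.

    Each embedding is L2-normalized, the results are averaged, and the
    average is normalized again so the class weights have unit length.
    """
    normalized = [l2_normalize(e) for e in embeddings]
    dim = len(normalized[0])
    mean = [sum(e[i] for e in normalized) / len(normalized) for i in range(dim)]
    return l2_normalize(mean)


def predict_class(embedding, class_weights):
    """Hypothetical helper: return the index of the best-matching class.

    Both the query and the weights are unit length, so the dot product
    equals the cosine similarity.
    """
    query = l2_normalize(embedding)
    scores = [sum(w * q for w, q in zip(weights, query)) for weights in class_weights]
    return max(range(len(scores)), key=scores.__getitem__)


# Two toy classes imprinted from a handful of 3-D "embeddings".
cats = imprint_weights([[0.9, 0.1, 0.0], [1.0, 0.2, 0.1]])
dogs = imprint_weights([[0.0, 1.0, 0.2], [0.1, 0.8, 0.0]])
print(predict_class([0.8, 0.15, 0.05], [cats, dogs]))  # 0 (closest to "cats")
```

Because only a normalized average is computed per class, a handful of examples is enough to add a class, which is what makes this approach suitable for low-shot, on-device training.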

class pycoral.learn.imprinting.engine.ImprintingEngine(model_path, keep_classes=False)

Performs weight imprinting (transfer learning) with the given model.

Parameters
  • model_path (str) – Path to the model you want to retrain. This model must be a .tflite file output by the join_tflite_models tool. For more information about how to create a compatible model, read Retrain an image classification model on-device.

  • keep_classes (bool) – If True, keep the existing classes from the pre-trained model (and use training to add additional classes). If False, drop the existing classes and train the model to include new classes only.

property embedding_dim

Returns the number of embedding dimensions.

property num_classes

Returns the number of currently trained classes.
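The two properties are useful for preparing calls to train(). The sketch below uses two hypothetical helpers (not part of pycoral): assign_class_ids hands out IDs for new labels starting at the current class count, following the rule that a new class's ID must equal the number of existing classes; check_embedding rejects vectors whose length differs from the embedding dimension, on the assumption that train() expects exactly embedding_dim values. In real code, the integers passed in would come from engine.num_classes and engine.embedding_dim.

```python
def assign_class_ids(num_classes, new_labels):
    """Hypothetical helper: map each new label to the class_id it should get.

    num_classes stands in for ImprintingEngine.num_classes before training.
    Because each new class must be introduced with an id equal to the current
    class count, ids are handed out sequentially from there. Labels are
    assumed to be unique.
    """
    return {label: num_classes + i for i, label in enumerate(new_labels)}


def check_embedding(embedding, embedding_dim):
    """Hypothetical helper: raise ValueError unless the length matches embedding_dim."""
    if len(embedding) != embedding_dim:
        raise ValueError(f"expected {embedding_dim} values, got {len(embedding)}")
    return embedding


ids = assign_class_ids(3, ["cat", "dog"])
print(ids)  # {'cat': 3, 'dog': 4}
```

Since each new ID is only valid once all lower IDs exist, the first embedding of each new label should be trained in ID order ("cat" before "dog" here).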

serialize_extractor_model()

Returns the embedding extractor model as a bytes object.

serialize_model()

Returns the newly trained model as a bytes object.
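Both serialize methods return raw bytes, so persisting a model is a plain binary file write. The sketch below uses a hypothetical save_model helper; in real code its first argument would be engine.serialize_model() (or engine.serialize_extractor_model()), while the example uses placeholder bytes so it runs without an Edge TPU.

```python
import os
import tempfile


def save_model(model_bytes, path):
    """Hypothetical helper: write a serialized model to disk and return its size.

    model_bytes stands in for the bytes returned by serialize_model() or
    serialize_extractor_model(); path would typically end in .tflite.
    """
    if not isinstance(model_bytes, (bytes, bytearray)):
        raise TypeError("expected a bytes object")
    with open(path, "wb") as f:  # binary mode: the model is not text
        f.write(model_bytes)
    return os.path.getsize(path)


# Placeholder bytes instead of a real model, written to a temporary directory.
with tempfile.TemporaryDirectory() as tmp:
    size = save_model(b"placeholder", os.path.join(tmp, "retrained.tflite"))
    print(size)  # 11
```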

train(embedding, class_id)

Trains the model with the given embedding for the specified class.

You can use this to add new classes to the model or retrain classes that you previously added using this imprinting API.

Parameters
  • embedding (numpy.array) – The embedding vector used to train the specified class.

  • class_id (int) – The label id for this class. The index must be either the number of existing classes (to add a new class to the model) or the index of an existing class that was trained using this imprinting API (you can’t retrain classes from the pre-trained model).
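The class_id rule can be expressed as a small check. This is a sketch of the rule as described above, not code from pycoral: validate_class_id is hypothetical, num_classes stands in for engine.num_classes, and num_pretrained_kept is an assumed count of pre-trained classes retained by keep_classes=True (0 when keep_classes=False), since those IDs cannot be retrained.

```python
def validate_class_id(class_id, num_classes, num_pretrained_kept):
    """Hypothetical check of a class_id against the documented train() rule.

    Returns "new" when class_id adds a class, "retrain" when it targets a class
    previously added with the imprinting API, and raises ValueError otherwise
    (including ids of pre-trained classes kept by keep_classes=True).
    """
    if class_id == num_classes:
        return "new"
    if num_pretrained_kept <= class_id < num_classes:
        return "retrain"
    raise ValueError(
        f"class_id {class_id} must be {num_classes} (new class) or in "
        f"[{num_pretrained_kept}, {num_classes}) (imprinted class)")


# 10 pre-trained classes kept, plus 2 classes already added by imprinting.
print(validate_class_id(12, 12, 10))  # new
print(validate_class_id(10, 12, 10))  # retrain
```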