TensorFlow Lite Micro APIs

The Coral Dev Board Micro allows you to run two types of TensorFlow models: TensorFlow Lite Micro models that run entirely on the microcontroller (MCU) and TensorFlow Lite models that are compiled for acceleration on the Coral Edge TPU. Although you can run TensorFlow Lite Micro models on either MCU core (M4 or M7), currently, you must execute Edge TPU models from the M7.

Note

If you have experience with TensorFlow Lite on other platforms (including other Coral boards and accelerators), much of the code to run inference on the Dev Board Micro should look familiar. However, the APIs for microcontrollers are different, so your code is not 100% portable.

To run any TensorFlow Lite model on the Dev Board Micro, you must use the TensorFlow interpreter provided by TensorFlow Lite for Microcontrollers (TFLM): tflite::MicroInterpreter. If you’re running a model on the Edge TPU, the only difference compared to running a model on the MCU is that you need to specify the Edge TPU custom op when you instantiate the tflite::MicroInterpreter (and your model must be compiled for the Edge TPU). The following steps describe the basic procedures to run inference on the Dev Board Micro using either type of model.

First, you need to perform some setup:

  1. If using the Edge TPU, power it on with OpenDevice():

    auto tpu_context = EdgeTpuManager::GetSingleton()->OpenDevice();
    if (!tpu_context) {
      printf("ERROR: Failed to get EdgeTpu context\r\n");
    }
    
  2. Load your .tflite model from a file into a byte array with LfsReadFile():

    constexpr char kModelPath[] =
        "/models/tf2_ssd_mobilenet_v2_coco17_ptq_edgetpu.tflite";
    std::vector<uint8_t> model;
    if (!LfsReadFile(kModelPath, &model)) {
      printf("ERROR: Failed to load %s\r\n", kModelPath);
    }
    

    Note: Some microcontroller ML apps instead load their model from a C array that’s compiled into the app. That’s also an option, but it’s intended for microcontrollers without a filesystem (the Dev Board Micro has a littlefs filesystem).

  3. Specify each of the TensorFlow ops required by your model with a MicroMutableOpResolver. When using a model compiled for the Edge TPU, you must include the kCustomOp with AddCustom(). For example:

    tflite::MicroMutableOpResolver<3> resolver;
    resolver.AddDequantize();
    resolver.AddDetectionPostprocess();
    resolver.AddCustom(kCustomOp, RegisterCustomOp());
    
  4. Specify the memory arena required for your model’s input, output, and intermediate tensors. To ensure 16-byte alignment (required by TFLM) and avoid running out of heap space, you should use either the STATIC_TENSOR_ARENA_IN_SDRAM or STATIC_TENSOR_ARENA_IN_OCRAM macro to allocate your tensor arena:

    constexpr int kTensorArenaSize = 8 * 1024 * 1024;
    STATIC_TENSOR_ARENA_IN_SDRAM(tensor_arena, kTensorArenaSize);
    

    Selecting the best arena size depends on the model and requires some trial and error: start with a small number such as 1024 and run the app; TFLM reports an error at runtime that tells you the size you actually need. (After AllocateTensors() succeeds, arena_used_bytes() also reports the optimal size.)

  5. Instantiate a MicroInterpreter, passing it your model, op resolver, tensor arena, and a MicroErrorReporter:

    tflite::MicroErrorReporter error_reporter;
    tflite::MicroInterpreter interpreter(tflite::GetModel(model.data()),
                                         resolver, tensor_arena,
                                         kTensorArenaSize, &error_reporter);
    
  6. Allocate all model tensors with AllocateTensors():

    if (interpreter.AllocateTensors() != kTfLiteOk) {
      printf("ERROR: AllocateTensors() failed\r\n");
    }
    

Now you’re ready to run each inference as follows:

  1. Get the allocated input tensor with input_tensor() and fill it with your input data. For example, if you’re using the Dev Board Micro camera, you can simply set the input tensor as the buffer for your CameraFrameFormat (see examples/detect_faces/). Or you can copy your input data using GetTensorData() and std::memcpy like this:

    auto* input_tensor = interpreter.input_tensor(0);
    std::memcpy(tflite::GetTensorData<uint8_t>(input_tensor), image.data(),
                image.size());
    
  2. Execute the model with Invoke():

    if (interpreter.Invoke() != kTfLiteOk) {
      printf("ERROR: Invoke() failed\r\n");
    }
    
  3. Similar to writing the input tensor, you can then read the output tensor with GetTensorData() by passing it output_tensor().

    However, instead of processing this output data yourself, you can use the APIs below that correspond to the type of model you’re running. For example, if you’re running an object detection model, instead of reading the output tensor directly, call GetDetectionResults() and pass it your MicroInterpreter. This function returns an Object for each detected object, which specifies the detected object’s label id, prediction score, and bounding-box coordinates:

    auto results = tensorflow::GetDetectionResults(&interpreter, 0.6, 3);
    printf("%s\r\n", tensorflow::FormatDetectionOutput(results).c_str());
    

See the following documentation for more code examples, each of which is taken from the coralmicro examples that you can browse at coralmicro/examples/.

Also check out the TensorFlow Lite for Microcontrollers documentation.

TFLM interpreter

This is just a small set of APIs from TensorFlow Lite for Microcontrollers (TFLM) that represent the core APIs you need to run inference on the Dev Board Micro. You can see the rest of the TFLM APIs in coralmicro/third_party/tflite-micro/.

Note

The version of TFLM included in coralmicro is not continuously updated, so some APIs might be different from the latest version of TFLM on GitHub.

For usage examples, see the following sections, such as Image classification.

class tflite::MicroInterpreter

Encapsulates a pre-trained model and drives the model inference.

Note

This class is not thread-safe. The client is responsible for ensuring serialized interaction to avoid data races and undefined behavior.

Public Functions

MicroInterpreter(const Model *model, const MicroOpResolver &op_resolver, uint8_t *tensor_arena, size_t tensor_arena_size, ErrorReporter *error_reporter, MicroResourceVariables *resource_variables = nullptr, MicroProfiler *profiler = nullptr)

Constructor. Creates an instance with an allocated tensor arena.

The lifetime of the model, op resolver, tensor arena, error reporter, resource variables, and profiler must be at least as long as that of the interpreter object, since the interpreter may need to access them at any time. This means that you should usually create them with the same scope as each other, for example having them all allocated on the stack as local variables through a top-level function. The interpreter doesn’t do any deallocation of any of the pointed-to objects; ownership remains with the caller.

Parameters
  • model – A trained TensorFlow Lite model.

  • op_resolver – The op resolver that contains all ops used by the model. This is usually an instance of tflite::MicroMutableOpResolver.

  • tensor_arena – The allocated memory for all intermediate tensor data.

  • tensor_arena_size – The size of tensor_arena.

  • error_reporter – Object to use for error reports.

  • resource_variables – Handles assign/read ops for resource variables. See Resource variable docs.

  • profiler – Handles profiling for op kernels and TFLM routines. See Profiling docs.

MicroInterpreter(const Model *model, const MicroOpResolver &op_resolver, MicroAllocator *allocator, ErrorReporter *error_reporter, MicroResourceVariables *resource_variables = nullptr, MicroProfiler *profiler = nullptr)

Constructor. Creates an instance using an existing MicroAllocator instance.

This constructor should be used when creating an allocator that needs to have allocation handled in more than one interpreter or for recording allocations inside the interpreter. The lifetime of the allocator must be as long as that of the interpreter object.

Parameters
  • model – A trained TensorFlow Lite model.

  • op_resolver – The op resolver that contains all ops used by the model. This is usually an instance of tflite::MicroMutableOpResolver.

  • allocator – The object that allocates all intermediate tensor data.

  • error_reporter – Object to use for error reports.

  • resource_variables – Handles assign/read ops for resource variables. See Resource variable docs.

  • profiler – Handles profiling for op kernels and TFLM routines. See Profiling docs.

TfLiteStatus AllocateTensors()

Allocates all the model’s necessary input, output and intermediate tensors.

This will redim dependent tensors using the input tensor dimensionality as given. This is relatively expensive. This must be called after the interpreter has been created and before running inference (and accessing tensor buffers), and must be called again if (and only if) an input tensor is resized.

Returns

Status of success or failure. Will fail if any of the ops in the model (other than those which were rewritten by delegates, if any) are not supported by the Interpreter’s OpResolver.

TfLiteStatus Invoke()

Invokes the model to run inference using allocated input tensors.

In order to support partial graph runs for strided models, this can return values other than kTfLiteOk and kTfLiteError.

TfLiteStatus SetMicroExternalContext(void *external_context_payload)

This is the recommended API for an application to pass an external payload pointer as an external context to kernels. The lifetime of the payload pointer should be at least as long as this interpreter. TFLM supports only one external context.

TfLiteTensor *input(size_t index)

Gets a mutable pointer to an input tensor.

Parameters

index – The index position of the input tensor. Must be between 0 and inputs_size().

Returns

The input tensor.

inline size_t inputs_size() const

Gets the number of input tensors.

inline const flatbuffers::Vector<int32_t> &inputs() const

Gets a read-only list of all inputs.

inline TfLiteTensor *input_tensor(size_t index)

Same as input().

template<class T>
inline T *typed_input_tensor(int tensor_index)

Gets a mutable pointer into the data of a given input tensor.

The given index must be between 0 and inputs_size().

TfLiteTensor *output(size_t index)

Gets a mutable pointer to an output tensor.

Parameters

index – The index position of the output tensor. Must be between 0 and outputs_size().

Returns

The output tensor.

inline size_t outputs_size() const

Gets the number of output tensors.

inline const flatbuffers::Vector<int32_t> &outputs() const

Gets a read-only list of all outputs.

inline TfLiteTensor *output_tensor(size_t index)

Same as output().

template<class T>
inline T *typed_output_tensor(int tensor_index)

Gets a mutable pointer into the data of a given output tensor.

The given index must be between 0 and outputs_size().
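
For example, here’s a minimal sketch (assuming an interpreter and a uint8 image buffer like those in the setup steps above) that uses the typed accessors instead of GetTensorData():

  auto* in = interpreter.typed_input_tensor<uint8_t>(0);
  std::memcpy(in, image.data(), image.size());

  if (interpreter.Invoke() == kTfLiteOk) {
    // Read the first output value directly from the typed output pointer.
    auto* out = interpreter.typed_output_tensor<uint8_t>(0);
    printf("First output value: %u\r\n", static_cast<unsigned>(out[0]));
  }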

TfLiteStatus ResetVariableTensors()

Reset all variable tensors to the default value.

TfLiteStatus PrepareNodeAndRegistrationDataFromFlatbuffer()

Populates node and registration pointers representing the inference graph of the model from values inside the flatbuffer (loaded from the TfLiteModel instance). Persistent data (e.g. operator data) is allocated from the arena.

inline size_t arena_used_bytes() const

For debugging only. Returns the actual used arena in bytes. This method gives the optimal arena size. It’s only available after AllocateTensors has been called. Note that normally the tensor arena requires 16-byte alignment to fully utilize the space. If that’s not the case, the optimal arena size would be arena_used_bytes() + 16.
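
For example, a quick sketch (using the interpreter from the setup steps above) to find a good kTensorArenaSize:

  if (interpreter.AllocateTensors() == kTfLiteOk) {
    // Reports how much of the arena the model actually needs, so you can
    // shrink kTensorArenaSize to about this value plus 16 bytes of headroom.
    printf("Arena used: %u bytes\r\n",
           static_cast<unsigned>(interpreter.arena_used_bytes()));
  }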

inline const tflite::Model *tflite::GetModel(const void *buf)

Creates a Model object with the given model data, which you need for the tflite::MicroInterpreter constructor.

Parameters

buf – The model data, either loaded from a C array or from a .tflite file.

Returns

The model object to use with the tflite::MicroInterpreter constructor.

template<unsigned int tOpCount>
class tflite::MicroMutableOpResolver : public tflite::MicroOpResolver

Maps ops in the loaded model to executable functions on the device.

You must use this object to specify each of the ops required by your model (using the various Add... functions), and then pass this to the tflite::MicroInterpreter constructor.

Public Functions

inline TfLiteStatus AddCustom(const char *name, TfLiteRegistration *registration)

Registers a Custom Operator with the MicroOpResolver.

Only the first call for a given name will be successful. That is, if this function is called again for a previously added Custom Operator, the MicroOpResolver will be unchanged and this function will return kTfLiteError.

Parameters
  • name – Name of the custom op.

  • registration – Handler for the custom op.

Returns

kTfLiteOk if successful; kTfLiteError otherwise.

Note

The tflite::MicroMutableOpResolver has a long list of Add... functions to specify the ops that you need for your model. To see them all, refer to the micro_mutable_op_resolver.h source code.
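
For example, here’s a sketch of a resolver for a hypothetical CPU-only quantized classifier (the exact op list depends on your model; the template argument must be at least the number of ops you register):

  tflite::MicroMutableOpResolver<6> resolver;
  resolver.AddConv2D();
  resolver.AddDepthwiseConv2D();
  resolver.AddAveragePool2D();
  resolver.AddReshape();
  resolver.AddFullyConnected();
  resolver.AddSoftmax();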

class tflite::MicroErrorReporter : public ErrorReporter

Reports errors for MicroInterpreter.

Public Functions

inline ~MicroErrorReporter() override
int Report(const char *format, va_list args) override
template<typename T>
T *tflite::micro::GetTensorData(TfLiteEvalTensor *tensor)

Gets the mutable data for a specified tensor.

Parameters

tensor – The tensor to read/write.

Returns

A pointer to the tensor data.

Edge TPU runtime

Note

The Edge TPU is not available within M4 programs.

These APIs provide access to the Edge TPU on the Dev Board Micro. Anytime you want to use the Edge TPU for acceleration with MicroInterpreter, you need to do two things:

  1. Start the Edge TPU with OpenDevice().

  2. Register the Edge TPU custom op with your interpreter by passing kCustomOp and RegisterCustomOp() to tflite::MicroMutableOpResolver::AddCustom().

Example (from examples/classify_images_file/):

  auto tpu_context = EdgeTpuManager::GetSingleton()->OpenDevice();
  if (!tpu_context) {
    printf("ERROR: Failed to get EdgeTpu context\r\n");
    return;
  }

  tflite::MicroErrorReporter error_reporter;
  tflite::MicroMutableOpResolver<1> resolver;
  resolver.AddCustom(kCustomOp, RegisterCustomOp());

  tflite::MicroInterpreter interpreter(tflite::GetModel(model.data()), resolver,
                                       tensor_arena, kTensorArenaSize,
                                       &error_reporter);

Note

Unlike the libcoral C++ API, when using this coralmicro C++ API, you do not need to pass the EdgeTpuContext to the tflite::MicroInterpreter, but the context must be opened and the custom op must be registered before you create an interpreter. (This is different because libcoral is based on TensorFlow Lite and coralmicro is based on TensorFlow Lite for Microcontrollers.)

namespace coralmicro
class EdgeTpuContext
#include <edgetpu_manager.h>

This class is a representation of the Edge TPU device, so there is one shared EdgeTpuContext used by all model interpreters.

Instances of this should be allocated with EdgeTpuManager::OpenDevice().

The EdgeTpuContext can be shared among multiple software components, and the life of this object is directly tied to the Edge TPU power, so the Edge TPU powers down after the last EdgeTpuContext reference leaves scope.

The lifetime of the EdgeTpuContext must be longer than all associated tflite::MicroInterpreter instances.

class EdgeTpuManager
#include <edgetpu_manager.h>

Singleton Edge TPU manager for allocating new instances of EdgeTpuContext.

Public Functions

std::shared_ptr<EdgeTpuContext> OpenDevice(PerformanceMode mode = PerformanceMode::kHigh)

Gets the default Edge TPU device (and starts it if necessary).

The Edge TPU device (represented by EdgeTpuContext) can be shared among multiple software components, and the EdgeTpuManager is a singleton object, so you should always call this function like this:

auto tpu_context = EdgeTpuManager::GetSingleton()->OpenDevice();
Parameters

mode – The PerformanceMode to use for the Edge TPU. Options are: kMax (500 MHz), kHigh (250 MHz), kMedium (125 MHz), or kLow (63 MHz). If omitted, the default is kHigh. Caution: If you set the performance mode to kMax, it can increase the Edge TPU inferencing speed, but it can also make the Edge TPU module hotter, which might cause burns if touched.

Returns

A shared pointer to Edge TPU device. The shared_ptr can point to nullptr in case of error.

std::optional<float> GetTemperature()

Gets the current Edge TPU junction temperature.

Returns

The temperature in Celsius, or std::nullopt if EdgeTpuContext is empty.
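
For example, a sketch that opens the Edge TPU in kMax mode and logs its temperature:

  auto tpu_context =
      EdgeTpuManager::GetSingleton()->OpenDevice(PerformanceMode::kMax);
  if (!tpu_context) {
    printf("ERROR: Failed to get EdgeTpu context\r\n");
    return;
  }
  // GetTemperature() returns std::nullopt if the Edge TPU is not powered on.
  auto temperature = EdgeTpuManager::GetSingleton()->GetTemperature();
  if (temperature) {
    printf("Edge TPU temperature: %.2f C\r\n", *temperature);
  }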

Public Static Functions

static inline EdgeTpuManager *GetSingleton()

Gets a pointer to the EdgeTpuManager singleton object.

namespace coralmicro

Functions

TfLiteRegistration *RegisterCustomOp()

Returns pointer to an instance of tflite::TfLiteRegistration to handle Edge TPU custom ops. Pass this to tflite::MicroMutableOpResolver::AddCustom().

Variables

constexpr char kCustomOp[] = "edgetpu-custom-op"

Edge TPU custom op. Pass this to tflite::MicroMutableOpResolver::AddCustom().

Image classification

These APIs simplify the pre- and post-processing for image classification models.

Example (from examples/classify_images_file/):

namespace coralmicro {
namespace {
constexpr char kModelPath[] =
    "/models/mobilenet_v1_1.0_224_quant_edgetpu.tflite";
constexpr char kImagePath[] = "/examples/classify_images_file/cat_224x224.rgb";
constexpr int kTensorArenaSize = 1024 * 1024;
STATIC_TENSOR_ARENA_IN_SDRAM(tensor_arena, kTensorArenaSize);

void Main() {
  printf("Classify Image Example!\r\n");
  // Turn on Status LED to show the board is on.
  LedSet(Led::kStatus, true);

  std::vector<uint8_t> model;
  if (!LfsReadFile(kModelPath, &model)) {
    printf("ERROR: Failed to load %s\r\n", kModelPath);
    return;
  }

  // [start-sphinx-snippet:edgetpu]
  auto tpu_context = EdgeTpuManager::GetSingleton()->OpenDevice();
  if (!tpu_context) {
    printf("ERROR: Failed to get EdgeTpu context\r\n");
    return;
  }

  tflite::MicroErrorReporter error_reporter;
  tflite::MicroMutableOpResolver<1> resolver;
  resolver.AddCustom(kCustomOp, RegisterCustomOp());

  tflite::MicroInterpreter interpreter(tflite::GetModel(model.data()), resolver,
                                       tensor_arena, kTensorArenaSize,
                                       &error_reporter);
  // [end-sphinx-snippet:edgetpu]
  if (interpreter.AllocateTensors() != kTfLiteOk) {
    printf("ERROR: AllocateTensors() failed\r\n");
    return;
  }

  if (interpreter.inputs().size() != 1) {
    printf("ERROR: Model must have only one input tensor\r\n");
    return;
  }

  auto* input_tensor = interpreter.input_tensor(0);
  if (!LfsReadFile(kImagePath, tflite::GetTensorData<uint8_t>(input_tensor),
                   input_tensor->bytes)) {
    printf("ERROR: Failed to load %s\r\n", kImagePath);
    return;
  }

  if (interpreter.Invoke() != kTfLiteOk) {
    printf("ERROR: Invoke() failed\r\n");
    return;
  }

  auto results = tensorflow::GetClassificationResults(&interpreter, 0.0f, 3);
  for (auto& result : results)
    printf("Label ID: %d Score: %f\r\n", result.id, result.score);
}
}  // namespace
}  // namespace coralmicro

extern "C" void app_main(void* param) {
  (void)param;
  coralmicro::Main();
  vTaskSuspend(nullptr);
}
namespace coralmicro
namespace tensorflow
struct Class
#include <classification.h>

Represents a classification result.

Public Members

int id

The class label id.

float score

The prediction score.

Functions

std::string FormatClassificationOutput(const std::vector<tensorflow::Class> &classes)

Formats the classification output into a string.

Parameters

classes – All the classification class predictions, as returned by GetClassificationResults().

Returns

A string with all predictions in a line-delimited list, with an id and score for each classification.
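
For example (a sketch, continuing the classification example above):

  auto results = tensorflow::GetClassificationResults(&interpreter, 0.0f, 3);
  printf("%s\r\n", tensorflow::FormatClassificationOutput(results).c_str());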

std::vector<Class> GetClassificationResults(const float *scores, ssize_t scores_count, float threshold = -std::numeric_limits<float>::infinity(), size_t top_k = std::numeric_limits<size_t>::max())

Converts a classification output tensor into a list of ordered classes.

Parameters
  • scores – The dequantized output tensor.

  • scores_count – The number of scores in the output (the size of the output tensor).

  • threshold – The score threshold for results. All returned results have a score greater-than-or-equal-to this value.

  • top_k – The maximum number of predictions to return.

Returns

The top_k Class predictions (id, score), ordered by score (first element has the highest score).

std::vector<Class> GetClassificationResults(tflite::MicroInterpreter *interpreter, float threshold = -std::numeric_limits<float>::infinity(), size_t top_k = std::numeric_limits<size_t>::max())

Gets results from a classification model as a list of ordered classes.

Parameters
  • interpreter – The already-invoked interpreter for your classification model.

  • threshold – The score threshold for results. All returned results have a score greater-than-or-equal-to this value.

  • top_k – The maximum number of predictions to return.

Returns

The top_k Class predictions (id, score), ordered by score (first element has the highest score).

bool ClassificationInputNeedsPreprocessing(const TfLiteTensor &input_tensor)

Checks whether an input tensor needs pre-processing for classification.

Parameters

input_tensor – The tensor intended as input for a classification model.

Returns

True if the input tensor requires normalization AND quantization (you should run ClassificationPreprocess()); false otherwise.

bool ClassificationPreprocess(TfLiteTensor *input_tensor)

Performs normalization and quantization pre-processing on the given tensor.

Parameters

input_tensor – The tensor you want to pre-process for a classification model.

Returns

True upon success; false if the tensor type is the wrong format.
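
For example, a sketch that preprocesses the input tensor (after filling it with raw image data) only when the model needs it:

  auto* input_tensor = interpreter.input_tensor(0);
  if (tensorflow::ClassificationInputNeedsPreprocessing(*input_tensor)) {
    if (!tensorflow::ClassificationPreprocess(input_tensor)) {
      printf("ERROR: ClassificationPreprocess() failed\r\n");
      return;
    }
  }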

Object detection

These APIs simplify the post-processing for object detection models.

Example (from examples/detect_objects_file/):

namespace coralmicro {
namespace {
constexpr char kModelPath[] =
    "/models/tf2_ssd_mobilenet_v2_coco17_ptq_edgetpu.tflite";
constexpr char kImagePath[] = "/examples/detect_objects_file/cat_300x300.rgb";
constexpr int kTensorArenaSize = 8 * 1024 * 1024;
STATIC_TENSOR_ARENA_IN_SDRAM(tensor_arena, kTensorArenaSize);

void Main() {
  printf("Detect Image Example!\r\n");
  // Turn on Status LED to show the board is on.
  LedSet(Led::kStatus, true);

  std::vector<uint8_t> model;
  if (!LfsReadFile(kModelPath, &model)) {
    printf("ERROR: Failed to load %s\r\n", kModelPath);
    return;
  }

  auto tpu_context = EdgeTpuManager::GetSingleton()->OpenDevice();
  if (!tpu_context) {
    printf("ERROR: Failed to get EdgeTpu context\r\n");
    return;
  }

  tflite::MicroErrorReporter error_reporter;
  tflite::MicroMutableOpResolver<3> resolver;
  resolver.AddDequantize();
  resolver.AddDetectionPostprocess();
  resolver.AddCustom(kCustomOp, RegisterCustomOp());

  tflite::MicroInterpreter interpreter(tflite::GetModel(model.data()), resolver,
                                       tensor_arena, kTensorArenaSize,
                                       &error_reporter);
  if (interpreter.AllocateTensors() != kTfLiteOk) {
    printf("ERROR: AllocateTensors() failed\r\n");
    return;
  }

  if (interpreter.inputs().size() != 1) {
    printf("ERROR: Model must have only one input tensor\r\n");
    return;
  }

  auto* input_tensor = interpreter.input_tensor(0);
  if (!LfsReadFile(kImagePath, tflite::GetTensorData<uint8_t>(input_tensor),
                   input_tensor->bytes)) {
    printf("ERROR: Failed to load %s\r\n", kImagePath);
    return;
  }

  if (interpreter.Invoke() != kTfLiteOk) {
    printf("ERROR: Invoke() failed\r\n");
    return;
  }

  auto results = tensorflow::GetDetectionResults(&interpreter, 0.6, 3);
  printf("%s\r\n", tensorflow::FormatDetectionOutput(results).c_str());
}
}  // namespace
}  // namespace coralmicro

extern "C" void app_main(void* param) {
  (void)param;
  coralmicro::Main();
  vTaskSuspend(nullptr);
}
namespace coralmicro
namespace tensorflow
template<typename T>
struct BBox
#include <detection.h>

Represents the bounding box of a detected object.

Public Members

T ymin

The box y-minimum (top-most) point.

T xmin

The box x-minimum (left-most) point.

T ymax

The box y-maximum (bottom-most) point.

T xmax

The box x-maximum (right-most) point.

struct Object
#include <detection.h>

Represents a detected object.

Public Members

int id

The class label id.

float score

The prediction score.

BBox<float> bbox

The bounding-box (ymin,xmin,ymax,xmax).

Functions

std::string FormatDetectionOutput(const std::vector<Object> &objects)

Formats the detection outputs into a string.

Parameters

objects – A vector with all the objects in an object detection output.

Returns

A description of all detected objects.

std::vector<Object> GetDetectionResults(const float *bboxes, const float *ids, const float *scores, size_t count, float threshold = -std::numeric_limits<float>::infinity(), size_t top_k = std::numeric_limits<size_t>::max())

Converts detection output tensors into a vector of Objects.

Parameters
  • bboxes – The output tensor for all detected bounding boxes in box-corner encoding, for example: (ymin1,xmin1,ymax1,xmax1,ymin2,xmin2,…).

  • ids – The output tensor for all label IDs.

  • scores – The output tensor for all scores.

  • count – The number of detected objects (all tensors defined above have valid data for this number of objects).

  • threshold – The score threshold for results. All returned results have a score greater-than-or-equal-to this value.

  • top_k – The maximum number of predictions to return.

Returns

The top_k object predictions (id, score, BBox), ordered by score (first element has the highest score).

std::vector<Object> GetDetectionResults(tflite::MicroInterpreter *interpreter, float threshold = -std::numeric_limits<float>::infinity(), size_t top_k = std::numeric_limits<size_t>::max())

Gets results from a detection model as a vector of Objects.

Parameters
  • interpreter – The already-invoked interpreter for your detection model.

  • threshold – The score threshold for results. All returned results have a score greater-than-or-equal-to this value.

  • top_k – The maximum number of predictions to return.

Returns

The top_k object predictions (id, score, BBox), ordered by score (first element has the highest score).
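
For example, a sketch that prints each detected object instead of using FormatDetectionOutput():

  auto results = tensorflow::GetDetectionResults(&interpreter, 0.5f, 3);
  for (const auto& object : results) {
    printf("id: %d score: %f bbox: (%f, %f, %f, %f)\r\n", object.id,
           object.score, object.bbox.ymin, object.bbox.xmin, object.bbox.ymax,
           object.bbox.xmax);
  }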

Pose estimation

These APIs not only simplify the post-processing for pose estimation with PoseNet, but also optimize execution of the post-processing layers on the MCU with a custom op (because the post-processing ops are not compatible with the Edge TPU).

So when running PoseNet, in addition to specifying the kCustomOp for the Edge TPU, you should also register the kPosenetDecoderOp provided here.

Example (from examples/detect_poses/):

namespace coralmicro {
namespace {

constexpr int kModelArenaSize = 1 * 1024 * 1024;
constexpr int kExtraArenaSize = 1 * 1024 * 1024;
constexpr int kTensorArenaSize = kModelArenaSize + kExtraArenaSize;
STATIC_TENSOR_ARENA_IN_SDRAM(tensor_arena, kTensorArenaSize);
constexpr char kModelPath[] =
    "/models/posenet_mobilenet_v1_075_324_324_16_quant_decoder_edgetpu.tflite";
constexpr char kTestInputPath[] = "/models/posenet_test_input_324.bin";

void Main() {
  printf("Posenet Example!\r\n");
  // Turn on Status LED to show the board is on.
  LedSet(Led::kStatus, true);

  tflite::MicroErrorReporter error_reporter;
  TF_LITE_REPORT_ERROR(&error_reporter, "Posenet!");
  // Turn on the TPU and get its context.
  auto tpu_context =
      EdgeTpuManager::GetSingleton()->OpenDevice(PerformanceMode::kMax);
  if (!tpu_context) {
    printf("ERROR: Failed to get EdgeTpu context\r\n");
    vTaskSuspend(nullptr);
  }
  // Reads the model and checks version.
  std::vector<uint8_t> posenet_tflite;
  if (!LfsReadFile(kModelPath, &posenet_tflite)) {
    TF_LITE_REPORT_ERROR(&error_reporter, "Failed to load model!");
    vTaskSuspend(nullptr);
  }
  auto* model = tflite::GetModel(posenet_tflite.data());
  if (model->version() != TFLITE_SCHEMA_VERSION) {
    TF_LITE_REPORT_ERROR(&error_reporter,
                         "Model schema version is %d, supported is %d",
                         model->version(), TFLITE_SCHEMA_VERSION);
    vTaskSuspend(nullptr);
  }
  // Creates a micro interpreter.
  tflite::MicroMutableOpResolver<2> resolver;
  resolver.AddCustom(kCustomOp, RegisterCustomOp());
  resolver.AddCustom(kPosenetDecoderOp, RegisterPosenetDecoderOp());
  auto interpreter = tflite::MicroInterpreter{
      model, resolver, tensor_arena, kTensorArenaSize, &error_reporter};
  if (interpreter.AllocateTensors() != kTfLiteOk) {
    TF_LITE_REPORT_ERROR(&error_reporter, "AllocateTensors failed.");
    vTaskSuspend(nullptr);
  }
  auto* posenet_input = interpreter.input(0);
  // Runs posenet on a test image.
  printf("Getting outputs for posenet test input\r\n");
  std::vector<uint8_t> posenet_test_input_bin;
  if (!LfsReadFile(kTestInputPath, &posenet_test_input_bin)) {
    TF_LITE_REPORT_ERROR(&error_reporter, "Failed to load test input!");
    vTaskSuspend(nullptr);
  }
  if (posenet_input->bytes != posenet_test_input_bin.size()) {
    TF_LITE_REPORT_ERROR(&error_reporter,
                         "Input tensor length doesn't match canned input\r\n");
    vTaskSuspend(nullptr);
  }
  memcpy(tflite::GetTensorData<uint8_t>(posenet_input),
         posenet_test_input_bin.data(), posenet_test_input_bin.size());
  if (interpreter.Invoke() != kTfLiteOk) {
    TF_LITE_REPORT_ERROR(&error_reporter, "Invoke failed.");
    vTaskSuspend(nullptr);
  }
  auto test_image_output =
      tensorflow::GetPosenetOutput(&interpreter, /*threshold=*/0.5);
  printf("%s\r\n", tensorflow::FormatPosenetOutput(test_image_output).c_str());
  // Starts the camera for live poses.
  CameraTask::GetSingleton()->SetPower(true);
  CameraTask::GetSingleton()->Enable(CameraMode::kStreaming);
  printf("Starting live posenet\r\n");
  auto model_height = posenet_input->dims->data[1];
  auto model_width = posenet_input->dims->data[2];
  for (;;) {
    CameraFrameFormat fmt{
        /*fmt=*/CameraFormat::kRgb,
        /*filter=*/CameraFilterMethod::kBilinear,
        /*rotation=*/CameraRotation::k270,
        /*width=*/model_width,
        /*height=*/model_height,
        /*preserve_ratio=*/false,
        /*buffer=*/tflite::GetTensorData<uint8_t>(posenet_input)};
    if (!CameraTask::GetSingleton()->GetFrame({fmt})) {
      TF_LITE_REPORT_ERROR(&error_reporter, "Failed to get image from camera.");
      break;
    }
    if (interpreter.Invoke() != kTfLiteOk) {
      TF_LITE_REPORT_ERROR(&error_reporter, "Invoke failed.");
      break;
    }
    auto output = tensorflow::GetPosenetOutput(&interpreter,
                                               /*threshold=*/0.5);
    printf("%s\r\n", tensorflow::FormatPosenetOutput(output).c_str());
    vTaskDelay(pdMS_TO_TICKS(100));
  }
  CameraTask::GetSingleton()->SetPower(false);
}

}  // namespace
}  // namespace coralmicro

extern "C" void app_main(void* param) {
  (void)param;
  coralmicro::Main();
  vTaskSuspend(nullptr);
}
namespace coralmicro

Functions

TfLiteRegistration *RegisterPosenetDecoderOp()

Returns pointer to an instance of tflite::TfLiteRegistration to handle the custom op for post-processing PoseNet output tensors on the MCU. Pass this to tflite::MicroMutableOpResolver::AddCustom().

Variables

constexpr char kPosenetDecoderOp[] = "PosenetDecoderOp"

PoseNet custom op name. Pass this to tflite::MicroMutableOpResolver::AddCustom().

namespace coralmicro
namespace tensorflow

Functions

std::string FormatPosenetOutput(const std::vector<Pose> &poses)

Formats all the PoseNet output into a string.

Parameters

poses – A vector containing all the poses in a PoseNet output.

Returns

A string showing the PoseNet output.

std::vector<Pose> GetPosenetOutput(tflite::MicroInterpreter *interpreter, float threshold = -std::numeric_limits<float>::infinity())

Gets the results from a PoseNet model in the form of a vector of poses.

After you invoke the interpreter, pass it to this function to get structured pose results.

Parameters
  • interpreter – The already-invoked interpreter for your PoseNet model.

  • threshold – The overall pose score threshold for results.

Returns

All detected poses with an overall score greater-than-or-equal-to the threshold.

Variables

constexpr int kKeypoints = 17

Number of keypoints in each pose.

const char *const KeypointTypes[] = {"NOSE", "LEFT_EYE", "RIGHT_EYE", "LEFT_EAR", "RIGHT_EAR", "LEFT_SHOULDER", "RIGHT_SHOULDER", "LEFT_ELBOW", "RIGHT_ELBOW", "LEFT_WRIST", "RIGHT_WRIST", "LEFT_HIP", "RIGHT_HIP", "LEFT_KNEE", "RIGHT_KNEE", "LEFT_ANKLE", "RIGHT_ANKLE",}

A map of keypoint index to the keypoint name.

struct Keypoint
#include <posenet.h>

The location and score of a pose keypoint.

Public Members

float x

The keypoint’s x position, relative to the image size (0 to 1.0).

float y

The keypoint’s y position, relative to the image size (0 to 1.0).

float score

The keypoint’s prediction score (0 to 1.0).

struct Pose
#include <posenet.h>

Represents an individual pose.

Public Members

float score

The pose’s overall prediction score.

Keypoint keypoints[kKeypoints]

An array of keypoints in this pose.
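
For example, a sketch that walks each detected pose and prints every keypoint by name (using the already-invoked interpreter from the PoseNet example above):

  auto poses = tensorflow::GetPosenetOutput(&interpreter, /*threshold=*/0.5);
  for (const auto& pose : poses) {
    printf("Pose score: %f\r\n", pose.score);
    for (int i = 0; i < tensorflow::kKeypoints; ++i) {
      const auto& keypoint = pose.keypoints[i];
      printf("  %s: x=%f y=%f score=%f\r\n", tensorflow::KeypointTypes[i],
             keypoint.x, keypoint.y, keypoint.score);
    }
  }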

Audio Classification

The following APIs assist with running audio classification models on the Dev Board Micro, either on the CPU or the Edge TPU. For supported models, see the audio classification models.

For an example, see examples/classify_speech/.

namespace coralmicro
namespace tensorflow

Enums

enum AudioModel

Supported models.

Values:

enumerator kYAMNet

YamNet without the frontend.

enumerator kKeywordDetector

Keyword detector (or “Keyword Spotter”).

Functions

template<bool tForTpu>
auto SetupYamNetResolver()

Sets up the MicroMutableOpResolver with ops required for YamNet.

Template Parameters

tForTpu – If true, the resolver is set up for the Edge TPU; otherwise, for the CPU.

Returns

A tflite::MicroMutableOpResolver that is prepared for the YamNet model.
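
For example, a minimal sketch (the resolver is then passed to the MicroInterpreter constructor as usual):

  // Builds a resolver with the ops YamNet needs; when tForTpu is true, the
  // resolver is set up for the Edge TPU version of the model.
  auto resolver = tensorflow::SetupYamNetResolver</*tForTpu=*/true>();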

bool PrepareAudioFrontEnd(FrontendState *frontend_state, AudioModel model_type)

Prepares the input preprocessing engine that converts raw audio data to a spectrogram for TensorFlow. This function must be called before YamNetPreprocessInput() or KeywordDetectorPreprocessInput().

Parameters
  • frontend_state – The FrontendState struct to populate.

  • model_type – The type of audio model.

Returns

True if FrontendPopulateState() succeeds; false otherwise.

void YamNetPreprocessInput(const int16_t *audio_data, TfLiteTensor *input_tensor, FrontendState *frontend_state)

Performs input preprocessing to convert raw audio input to a spectrogram.

Parameters
  • audio_data – An array of signed int16 audio data.

  • input_tensor – The tensor where the preprocessed spectrogram data is stored.

  • frontend_state – The populated frontend state used to preprocess the input tensor; must not be nullptr.
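
For example, a sketch of the preprocessing flow (FrontendState comes from the TFLM audio frontend library; audio_samples is a hypothetical buffer of kYamnetAudioSize int16_t samples recorded at kYamnetSampleRate):

  FrontendState frontend_state;
  if (!tensorflow::PrepareAudioFrontEnd(&frontend_state,
                                        tensorflow::AudioModel::kYAMNet)) {
    printf("ERROR: PrepareAudioFrontEnd() failed\r\n");
    return;
  }
  // Converts the raw samples into the spectrogram expected by YamNet and
  // writes it into the model's input tensor.
  tensorflow::YamNetPreprocessInput(audio_samples.data(),
                                    interpreter.input_tensor(0),
                                    &frontend_state);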

void KeywordDetectorPreprocessInput(const int16_t *audio_data, TfLiteTensor *input_tensor, FrontendState *frontend_state)

Performs input preprocessing to convert raw audio input to a spectrogram.

Parameters
  • audio_data – An array of signed int16 audio data.

  • input_tensor – The tensor where the preprocessed spectrogram data is stored; must not be nullptr.

  • frontend_state – The populated frontend state used to preprocess the input tensor; must not be nullptr.

Variables

constexpr int kYamnetSampleRate = 16000
constexpr int kYamnetSampleRateMs = kYamnetSampleRate / 1000
constexpr int kYamnetDurationMs = 975
constexpr int kYamnetAudioSize = kYamnetSampleRate * kYamnetDurationMs / 1000
constexpr int kYamnetFeatureSliceSize = 64
constexpr int kYamnetFeatureSliceCount = 96
constexpr int kYamnetFeatureElementCount = (kYamnetFeatureSliceSize * kYamnetFeatureSliceCount)
constexpr int kYamnetFeatureSliceStrideMs = 10
constexpr int kYamnetFeatureSliceDurationMs = 25
constexpr int kKeywordDetectorSampleRate = 16000
constexpr int kKeywordDetectorSampleRateMs = kKeywordDetectorSampleRate / 1000
constexpr int kKeywordDetectorDurationMs = 2000
constexpr int kKeywordDetectorAudioSize = kKeywordDetectorSampleRate * kKeywordDetectorDurationMs / 1000
constexpr int kKeywordDetectorFeatureSliceSize = 32
constexpr int kKeywordDetectorFeatureSliceCount = 198
constexpr int kKeywordDetectorFeatureElementCount = (kKeywordDetectorFeatureSliceSize * kKeywordDetectorFeatureSliceCount)
constexpr int kKeywordDetectorFeatureSliceStrideMs = 10
constexpr int kKeywordDetectorFeatureSliceDurationMs = 25

Utilities

The following functions help with some common tasks during inferencing, such as manipulating images and tensors.

namespace coralmicro
namespace tensorflow
struct ImageDims
#include <utils.h>

Represents the dimensions of an image.

Public Members

int height

Pixel height.

int width

Pixel width.

int depth

Channel depth.

Functions

inline bool operator==(const ImageDims &a, const ImageDims &b)

Compares two ImageDims objects for equality.

inline int ImageSize(const ImageDims &dims)

Gets the size of an image with the given dimensions.

bool ResizeImage(const ImageDims &in_dims, const uint8_t *uin, const ImageDims &out_dims, uint8_t *uout)

Resizes a bitmap image.

Parameters
  • in_dims – The current dimensions for image uin.

  • uin – The input image location.

  • out_dims – The desired dimensions for image uout.

  • uout – The output image location.
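
For example, a sketch that downscales a hypothetical 320x240 RGB frame into a 224x224 model input (frame is a hypothetical image buffer):

  tensorflow::ImageDims in_dims{/*height=*/240, /*width=*/320, /*depth=*/3};
  tensorflow::ImageDims out_dims{/*height=*/224, /*width=*/224, /*depth=*/3};
  if (!tensorflow::ResizeImage(in_dims, frame.data(), out_dims,
                               tflite::GetTensorData<uint8_t>(input_tensor))) {
    printf("ERROR: ResizeImage() failed\r\n");
  }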

inline int TensorSize(TfLiteTensor *tensor)

Gets the size of a tensor.

Parameters

tensor – The tensor whose size you want to get.

Returns

The size of the tensor.

template<typename I, typename O>
void Dequantize(int tensor_size, I *tensor_data, O *dequant_data, float scale, float zero_point)

Dequantizes data.

Parameters
  • tensor_size – The tensor’s size.

  • tensor_data – The tensor’s data.

  • dequant_data – The buffer to return the dequantized data to.

  • scale – The scale of the input tensor.

  • zero_point – The zero point of the input tensor.

Template Parameters
  • I – The data type of tensor_data.

  • O – The desired data type of the dequantized data. Note: You should instead use DequantizeTensor().

template<typename T>
std::vector<T> DequantizeTensor(TfLiteTensor *tensor)

Dequantizes a tensor.

Parameters

tensor – The tensor to dequantize.

Template Parameters

T – The desired output type of the dequantized data. When using a model adapter API such as GetClassificationResults(), this dequantization is done for you.

Returns

A vector of dequantized data.
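
For example, a sketch that reads a raw output tensor as floats:

  // Dequantizes output tensor 0 into a std::vector<float>.
  auto scores =
      tensorflow::DequantizeTensor<float>(interpreter.output_tensor(0));
  printf("First score: %f\r\n", scores.empty() ? 0.f : scores[0]);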

Defines

STATIC_TENSOR_ARENA_IN_SDRAM(name, size)

Allocates a uint8_t tensor arena statically in the Dev Board Micro SDRAM (max of 64 MB). This is slightly slower than OCRAM due to off-chip I/O overhead costs.

Parameters
  • name – The variable name for this allocation.

  • size – The byte size to allocate. This macro automatically aligns the size to 16 bytes.

STATIC_TENSOR_ARENA_IN_OCRAM(name, size)

Allocates a uint8_t tensor arena statically in the RT1176 on-chip RAM (max of 1.25 MB).

Parameters
  • name – The variable name for this allocation.

  • size – The byte size to allocate. This macro automatically aligns the size to 16 bytes.
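
For example, a sketch that places a small arena in on-chip RAM for a model whose tensors fit within the 1.25 MB limit:

  constexpr int kTensorArenaSize = 512 * 1024;
  STATIC_TENSOR_ARENA_IN_OCRAM(tensor_arena, kTensorArenaSize);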