TensorFlow Lite Micro APIs
The Coral Dev Board Micro allows you to run two types of TensorFlow models: TensorFlow Lite Micro models that run entirely on the microcontroller (MCU) and TensorFlow Lite models that are compiled for acceleration on the Coral Edge TPU. Although you can run TensorFlow Lite Micro models on either MCU core (M4 or M7), you currently must execute Edge TPU models from the M7.
Note
If you have experience with TensorFlow Lite on other platforms (including other Coral boards/accelerators), a lot of the code to run inference on the Dev Board Micro should be familiar, but the APIs are actually different for microcontrollers, so your code is not 100% portable.
To run any TensorFlow Lite model on the Dev Board Micro, you must use the TensorFlow interpreter provided by TensorFlow Lite for Microcontrollers (TFLM): tflite::MicroInterpreter.
If you're running a model on the Edge TPU, the only difference compared to running a model on the MCU is that you need to specify the Edge TPU custom op when you instantiate the tflite::MicroInterpreter (and your model must be compiled for the Edge TPU).
The following steps describe the basic procedures to run inference on the
Dev Board Micro using either type of model.
First, you need to perform some setup:
If using the Edge TPU, power on the Edge TPU with OpenDevice():
auto tpu_context = EdgeTpuManager::GetSingleton()->OpenDevice();
if (!tpu_context) {
  printf("ERROR: Failed to get EdgeTpu context\r\n");
}
Load your .tflite model from a file into a byte array with LfsReadFile():
constexpr char kModelPath[] = "/models/tf2_ssd_mobilenet_v2_coco17_ptq_edgetpu.tflite";
std::vector<uint8_t> model;
if (!LfsReadFile(kModelPath, &model)) {
  printf("ERROR: Failed to load %s\r\n", kModelPath);
}
Note: Some micro ML apps instead load their model from a C array that's compiled into the app. That's also an option here, but it's intended for microcontrollers without a filesystem (the Dev Board Micro has a littlefs filesystem).
Specify each of the TensorFlow ops required by your model with a MicroMutableOpResolver. When using a model compiled for the Edge TPU, you must include the kCustomOp with AddCustom(). For example:
tflite::MicroMutableOpResolver<3> resolver;
resolver.AddDequantize();
resolver.AddDetectionPostprocess();
resolver.AddCustom(kCustomOp, RegisterCustomOp());
Specify the memory arena required for your model's input, output, and intermediate tensors. To ensure 16-byte alignment (required by TFLM) and avoid running out of heap space, you should use either the STATIC_TENSOR_ARENA_IN_SDRAM or STATIC_TENSOR_ARENA_IN_OCRAM macro to allocate your tensor arena:
constexpr int kTensorArenaSize = 8 * 1024 * 1024;
STATIC_TENSOR_ARENA_IN_SDRAM(tensor_arena, kTensorArenaSize);
Selecting the best arena size depends on the model and requires some trial-and-error: Just start with a small number like 1024 and run it; TFLM will throw an error at runtime and tell you the size you actually need.
Instantiate a MicroInterpreter, passing it your model, op resolver, tensor arena, and a MicroErrorReporter:
tflite::MicroErrorReporter error_reporter;
tflite::MicroInterpreter interpreter(tflite::GetModel(model.data()), resolver,
                                     tensor_arena, kTensorArenaSize,
                                     &error_reporter);
Allocate all model tensors with AllocateTensors():
if (interpreter.AllocateTensors() != kTfLiteOk) {
  printf("ERROR: AllocateTensors() failed\r\n");
}
Now you’re ready to run each inference as follows:
Get the allocated input tensor with input_tensor() and fill it with your input data. For example, if you're using the Dev Board Micro camera, you can simply set the input tensor as the buffer for your CameraFrameFormat (see examples/detect_faces/). Or you can copy your input data using GetTensorData() and std::memcpy like this:
auto* input_tensor = interpreter.input_tensor(0);
std::memcpy(tflite::GetTensorData<uint8_t>(input_tensor), image.data(),
            image.size());
Execute the model with Invoke():
if (interpreter.Invoke() != kTfLiteOk) {
  printf("ERROR: Invoke() failed\r\n");
}
Similar to writing the input tensor, you can then read the output tensor with GetTensorData() by passing it output_tensor(). However, instead of processing this output data yourself, you can use the APIs below that correspond to the type of model you're running. For example, if you're running an object detection model, instead of reading the output tensor directly, call GetDetectionResults() and pass it your MicroInterpreter. This function returns an Object for each detected object, which specifies the detected object's label id, prediction score, and bounding-box coordinates:
auto results = tensorflow::GetDetectionResults(&interpreter, 0.6, 3);
printf("%s\r\n", tensorflow::FormatDetectionOutput(results).c_str());
See the following documentation for more code examples, each of which comes from the coralmicro examples, which you can browse at coralmicro/examples/.
Also check out the TensorFlow Lite for Microcontrollers documentation.
TFLM interpreter
This is just a small set of APIs from TensorFlow Lite for Microcontrollers
(TFLM) that represent the core APIs you need to run inference on the Dev Board
Micro. You can see the rest of the TFLM APIs in
coralmicro/third_party/tflite-micro/
.
Note
The version of TFLM included in coralmicro is not continuously updated, so some APIs might be different from the latest version of TFLM on GitHub.
For usage examples, see the following sections, such as for image classification.
-
class tflite::MicroInterpreter
Encapsulates a pre-trained model and drives the model inference.
Note
This class is not thread-safe. The client is responsible for ensuring serialized interaction to avoid data races and undefined behavior.
Public Functions
-
MicroInterpreter(const Model *model, const MicroOpResolver &op_resolver, uint8_t *tensor_arena, size_t tensor_arena_size, ErrorReporter *error_reporter, MicroResourceVariables *resource_variables = nullptr, MicroProfiler *profiler = nullptr)
Constructor. Creates an instance with an allocated tensor arena.
The lifetime of the model, op resolver, tensor arena, error reporter, resource variables, and profiler must be at least as long as that of the interpreter object, since the interpreter may need to access them at any time. This means that you should usually create them with the same scope as each other, for example having them all allocated on the stack as local variables through a top-level function. The interpreter doesn't do any deallocation of any of the pointed-to objects; ownership remains with the caller.
- Parameters
model – A trained TensorFlow Lite model.
op_resolver – The op resolver that contains all ops used by the model. This is usually an instance of tflite::MicroMutableOpResolver.
tensor_arena – The allocated memory for all intermediate tensor data.
tensor_arena_size – The size of tensor_arena.
error_reporter – Object to use for error reports.
resource_variables – Handles assign/read ops for resource variables. See Resource variable docs.
profiler – Handles profiling for op kernels and TFLM routines. See Profiling docs.
-
MicroInterpreter(const Model *model, const MicroOpResolver &op_resolver, MicroAllocator *allocator, ErrorReporter *error_reporter, MicroResourceVariables *resource_variables = nullptr, MicroProfiler *profiler = nullptr)
Constructor. Creates an instance using an existing MicroAllocator instance.
This constructor should be used when creating an allocator that needs to have allocation handled in more than one interpreter or for recording allocations inside the interpreter. The lifetime of the allocator must be as long as that of the interpreter object.
- Parameters
model – A trained TensorFlow Lite model.
op_resolver – The op resolver that contains all ops used by the model. This is usually an instance of tflite::MicroMutableOpResolver.
allocator – The object that allocates all intermediate tensor data.
error_reporter – Object to use for error reports.
resource_variables – Handles assign/read ops for resource variables. See Resource variable docs.
profiler – Handles profiling for op kernels and TFLM routines. See Profiling docs.
-
TfLiteStatus AllocateTensors()
Allocates all the model's necessary input, output, and intermediate tensors.
This will resize dependent tensors using the input tensor dimensionality as given. This is relatively expensive. This must be called after the interpreter has been created and before running inference (and accessing tensor buffers), and must be called again if (and only if) an input tensor is resized.
- Returns
A status of success or failure. Will fail if any of the ops in the model (other than those which were rewritten by delegates, if any) are not supported by the Interpreter's OpResolver.
-
TfLiteStatus Invoke()
Invokes the model to run inference using allocated input tensors.
In order to support partial graph runs for strided models, this can return values other than kTfLiteOk and kTfLiteError.
-
TfLiteStatus SetMicroExternalContext(void *external_context_payload)
This is the recommended API for an application to pass an external payload pointer as an external context to kernels. The lifetime of the payload pointer should be at least as long as this interpreter. TFLM supports only one external context.
-
TfLiteTensor *input(size_t index)
Gets a mutable pointer to an input tensor.
- Parameters
index – The index position of the input tensor. Must be between 0 and inputs_size().
- Returns
The input tensor.
-
inline size_t inputs_size() const
Gets the number of input tensors.
-
inline const flatbuffers::Vector<int32_t> &inputs() const
Gets a read-only list of all inputs.
-
template<class T>
inline T *typed_input_tensor(int tensor_index)
Gets a mutable pointer into the data of a given input tensor.
The given index must be between 0 and inputs_size().
-
TfLiteTensor *output(size_t index)
Gets a mutable pointer to an output tensor.
- Parameters
index – The index position of the output tensor. Must be between 0 and outputs_size().
- Returns
The output tensor.
-
inline size_t outputs_size() const
Gets the number of output tensors.
-
inline const flatbuffers::Vector<int32_t> &outputs() const
Gets a read-only list of all outputs.
-
template<class T>
inline T *typed_output_tensor(int tensor_index)
Gets a mutable pointer into the data of a given output tensor.
The given index must be between 0 and outputs_size().
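For instance, a minimal sketch (assuming a uint8-quantized model and an interpreter that has already been invoked) that reads output data through the typed accessor:
// Sketch: read quantized scores from the first output tensor, assuming
// the model's output type is uint8.
const uint8_t* scores = interpreter.typed_output_tensor<uint8_t>(0);
printf("First score (quantized): %u\r\n", scores[0]);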
-
TfLiteStatus ResetVariableTensors()
Resets all variable tensors to the default value.
-
TfLiteStatus PrepareNodeAndRegistrationDataFromFlatbuffer()
Populates node and registration pointers representing the inference graph of the model from values inside the flatbuffer (loaded from the TfLiteModel instance). Persistent data (e.g. operator data) is allocated from the arena.
-
inline size_t arena_used_bytes() const
For debugging only. Returns the actual arena usage in bytes. This method gives the optimal arena size. It's only available after AllocateTensors has been called. Note that normally tensor_arena requires 16-byte alignment to fully utilize the space. If that's not the case, the optimal arena size is arena_used_bytes() + 16.
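As a rough sketch (assuming the interpreter and kTensorArenaSize from the setup steps above), you can report the actual usage after allocation to tune your arena size:
if (interpreter.AllocateTensors() == kTfLiteOk) {
  // Report how much of the arena was actually needed so kTensorArenaSize
  // can be trimmed (leave ~16 extra bytes for alignment slack).
  printf("Arena used: %u of %u bytes\r\n",
         static_cast<unsigned>(interpreter.arena_used_bytes()),
         static_cast<unsigned>(kTensorArenaSize));
}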
inline const tflite::Model *
tflite
::
GetModel
(const void *buf)¶ Creates a
Model
object with the given model data, which you need for thetflite::MicroInterpreter
constructor.- Parameters
buf – The model data, either loaded from a C array or from a
.tflite
file.- Returns
The model object to use with the
tflite::MicroInterpreter
constructor.
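For example, a small sketch (assuming the model bytes were already read into a std::vector<uint8_t> named model, as in the setup steps) that also checks the flatbuffer schema version before building an interpreter, mirroring the PoseNet example later on this page:
const tflite::Model* tflite_model = tflite::GetModel(model.data());
if (tflite_model->version() != TFLITE_SCHEMA_VERSION) {
  // Bail out if the model was built against an incompatible schema.
  printf("ERROR: Model schema version is %lu, supported is %d\r\n",
         static_cast<unsigned long>(tflite_model->version()),
         TFLITE_SCHEMA_VERSION);
  return;
}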
[micro_mutable_op_resolver.h source]
-
template<unsigned int tOpCount>
class tflite::MicroMutableOpResolver : public tflite::MicroOpResolver
Maps ops in the loaded model to executable functions on the device.
You must use this object to specify each of the ops required by your model (using the various Add... functions), and then pass this to the tflite::MicroInterpreter constructor.
Public Functions
-
inline TfLiteStatus AddCustom(const char *name, TfLiteRegistration *registration)
Registers a Custom Operator with the MicroOpResolver.
Only the first call for a given name will be successful. That is, if this function is called again for a previously added Custom Operator, the MicroOpResolver will be unchanged and this function will return kTfLiteError.
- Parameters
name – Name of the custom op.
registration – Handler for the custom op.
- Returns
kTfLiteOk if successful; kTfLiteError otherwise.
Note
The tflite::MicroMutableOpResolver has a long list of Add... functions to specify the ops that you need for your model. To see them all, refer to the micro_mutable_op_resolver.h source code.
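As an illustrative sketch only (the exact op list depends entirely on your model; these Add... calls are examples from the TFLM resolver, not a required set), a CPU-only MobileNet-style classifier might be resolved like this:
// The template argument is the number of ops you register; keep it in sync.
tflite::MicroMutableOpResolver<5> resolver;
resolver.AddConv2D();
resolver.AddDepthwiseConv2D();
resolver.AddAveragePool2D();
resolver.AddReshape();
resolver.AddSoftmax();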
[micro_error_reporter.h source]
-
class tflite::MicroErrorReporter : public ErrorReporter
Reports errors for MicroInterpreter.
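A minimal usage sketch (TF_LITE_REPORT_ERROR is the TFLM macro also used in the PoseNet example later on this page):
tflite::MicroErrorReporter error_reporter;
// The reporter forwards formatted messages to the device debug output.
TF_LITE_REPORT_ERROR(&error_reporter, "Tensor arena too small: need %d bytes", 123456);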
Edge TPU runtime
Note
The Edge TPU is not available within M4 programs.
These APIs provide access to the Edge TPU on the Dev Board Micro.
Anytime you want to use the Edge TPU for acceleration with MicroInterpreter, you need to do two things:
Start the Edge TPU with OpenDevice().
Register the Edge TPU custom op with your interpreter by passing kCustomOp and RegisterCustomOp() to tflite::MicroMutableOpResolver::AddCustom().
Example (from examples/classify_images_file/):
auto tpu_context = EdgeTpuManager::GetSingleton()->OpenDevice();
if (!tpu_context) {
printf("ERROR: Failed to get EdgeTpu context\r\n");
return;
}
tflite::MicroErrorReporter error_reporter;
tflite::MicroMutableOpResolver<1> resolver;
resolver.AddCustom(kCustomOp, RegisterCustomOp());
tflite::MicroInterpreter interpreter(tflite::GetModel(model.data()), resolver,
tensor_arena, kTensorArenaSize,
&error_reporter);
Note
Unlike the libcoral C++ API, when using this coralmicro C++ API, you
do not need to pass the EdgeTpuContext
to the
tflite::MicroInterpreter
, but the context must be opened and the custom
op must be registered before you create an interpreter. (This is different
because libcoral is based on TensorFlow Lite and coralmicro is based on
TensorFlow Lite for Microcontrollers.)
-
namespace coralmicro
-
class EdgeTpuContext
- #include <edgetpu_manager.h>
This class is a representation of the Edge TPU device, so there is one shared EdgeTpuContext used by all model interpreters.
Instances of this should be allocated with EdgeTpuManager::OpenDevice().
The EdgeTpuContext can be shared among multiple software components, and the life of this object is directly tied to the Edge TPU power, so the Edge TPU powers down after the last EdgeTpuContext reference leaves scope.
The lifetime of the EdgeTpuContext must be longer than all associated tflite::MicroInterpreter instances.
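A sketch of the intended lifetime pattern (keeping the shared_ptr alive for as long as any interpreter uses the Edge TPU):
// Keep the returned shared_ptr in a scope that outlives the interpreter;
// when the last copy is destroyed, the Edge TPU is powered down.
auto tpu_context = EdgeTpuManager::GetSingleton()->OpenDevice();
if (!tpu_context) {
  printf("ERROR: Failed to get EdgeTpu context\r\n");
  return;
}
// ... build the resolver and tflite::MicroInterpreter here, then Invoke() ...
// tpu_context goes out of scope last, after the interpreter is done.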
-
class EdgeTpuManager
- #include <edgetpu_manager.h>
Singleton Edge TPU manager for allocating new instances of EdgeTpuContext.
Public Functions
-
std::shared_ptr<EdgeTpuContext> OpenDevice(PerformanceMode mode = PerformanceMode::kHigh)
Gets the default Edge TPU device (and starts it if necessary).
The Edge TPU device (represented by EdgeTpuContext) can be shared among multiple software components, and the EdgeTpuManager is a singleton object, so you should always call this function like this:
auto tpu_context = EdgeTpuManager::GetSingleton()->OpenDevice();
- Parameters
mode – The PerformanceMode to use for the Edge TPU. Options are: kMax (500 MHz), kHigh (250 MHz), kMedium (125 MHz), or kLow (63 MHz). If omitted, the default is kHigh. Caution: If you set the performance mode to kMax, it can increase the Edge TPU inferencing speed, but it can also make the Edge TPU module hotter, which might cause burns if touched.
- Returns
A shared pointer to the Edge TPU device. The shared_ptr can point to nullptr in case of error.
-
std::optional<float> GetTemperature()
Gets the current Edge TPU junction temperature.
- Returns
The temperature in Celsius, or std::nullopt if EdgeTpuContext is empty.
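For example, a small sketch (assuming a TPU context has already been opened with OpenDevice()):
auto temp = EdgeTpuManager::GetSingleton()->GetTemperature();
if (temp.has_value()) {
  printf("Edge TPU temperature: %f C\r\n", *temp);
} else {
  printf("Edge TPU temperature unavailable\r\n");
}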
Public Static Functions
-
static inline EdgeTpuManager *GetSingleton()
Gets a pointer to the EdgeTpuManager singleton object.
-
namespace coralmicro
Functions
-
TfLiteRegistration *RegisterCustomOp()
Returns a pointer to an instance of tflite::TfLiteRegistration to handle Edge TPU custom ops. Pass this to tflite::MicroMutableOpResolver::AddCustom().
Variables
-
constexpr char kCustomOp[] = "edgetpu-custom-op"
Edge TPU custom op. Pass this to tflite::MicroMutableOpResolver::AddCustom().
Image classification
These APIs simplify the pre- and post-processing for image classification models.
Example (from examples/classify_images_file/):
namespace coralmicro {
namespace {
constexpr char kModelPath[] =
"/models/mobilenet_v1_1.0_224_quant_edgetpu.tflite";
constexpr char kImagePath[] = "/examples/classify_images_file/cat_224x224.rgb";
constexpr int kTensorArenaSize = 1024 * 1024;
STATIC_TENSOR_ARENA_IN_SDRAM(tensor_arena, kTensorArenaSize);
void Main() {
printf("Classify Image Example!\r\n");
// Turn on Status LED to show the board is on.
LedSet(Led::kStatus, true);
std::vector<uint8_t> model;
if (!LfsReadFile(kModelPath, &model)) {
printf("ERROR: Failed to load %s\r\n", kModelPath);
return;
}
// [start-sphinx-snippet:edgetpu]
auto tpu_context = EdgeTpuManager::GetSingleton()->OpenDevice();
if (!tpu_context) {
printf("ERROR: Failed to get EdgeTpu context\r\n");
return;
}
tflite::MicroErrorReporter error_reporter;
tflite::MicroMutableOpResolver<1> resolver;
resolver.AddCustom(kCustomOp, RegisterCustomOp());
tflite::MicroInterpreter interpreter(tflite::GetModel(model.data()), resolver,
tensor_arena, kTensorArenaSize,
&error_reporter);
// [end-sphinx-snippet:edgetpu]
if (interpreter.AllocateTensors() != kTfLiteOk) {
printf("ERROR: AllocateTensors() failed\r\n");
return;
}
if (interpreter.inputs().size() != 1) {
printf("ERROR: Model must have only one input tensor\r\n");
return;
}
auto* input_tensor = interpreter.input_tensor(0);
if (!LfsReadFile(kImagePath, tflite::GetTensorData<uint8_t>(input_tensor),
input_tensor->bytes)) {
printf("ERROR: Failed to load %s\r\n", kImagePath);
return;
}
if (interpreter.Invoke() != kTfLiteOk) {
printf("ERROR: Invoke() failed\r\n");
return;
}
auto results = tensorflow::GetClassificationResults(&interpreter, 0.0f, 3);
for (auto& result : results)
printf("Label ID: %d Score: %f\r\n", result.id, result.score);
}
} // namespace
} // namespace coralmicro
extern "C" void app_main(void* param) {
(void)param;
coralmicro::Main();
vTaskSuspend(nullptr);
}
-
namespace coralmicro
-
namespace tensorflow
-
struct Class
- #include <classification.h>
Represents a classification result.
Functions
-
std::string FormatClassificationOutput(const std::vector<tensorflow::Class> &classes)
Formats the classification output into a string.
- Parameters
classes – All the classification class predictions, as returned by GetClassificationResults().
- Returns
A string with all predictions in a line-delimited list, with the id and score for each classification.
-
std::vector<Class> GetClassificationResults(const float *scores, ssize_t scores_count, float threshold = -std::numeric_limits<float>::infinity(), size_t top_k = std::numeric_limits<size_t>::max())
Converts a classification output tensor into a list of ordered classes.
- Parameters
scores – The dequantized output tensor.
scores_count – The number of scores in the output (the size of the output tensor).
threshold – The score threshold for results. All returned results have a score greater-than-or-equal-to this value.
top_k – The maximum number of predictions to return.
- Returns
The top_k Class predictions (id, score), ordered by score (first element has the highest score).
-
std::vector<Class> GetClassificationResults(tflite::MicroInterpreter *interpreter, float threshold = -std::numeric_limits<float>::infinity(), size_t top_k = std::numeric_limits<size_t>::max())
Gets results from a classification model as a list of ordered classes.
- Parameters
interpreter – The already-invoked interpreter for your classification model.
threshold – The score threshold for results. All returned results have a score greater-than-or-equal-to this value.
top_k – The maximum number of predictions to return.
- Returns
The top_k Class predictions (id, score), ordered by score (first element has the highest score).
-
bool ClassificationInputNeedsPreprocessing(const TfLiteTensor &input_tensor)
Checks whether an input tensor needs pre-processing for classification.
- Parameters
input_tensor – The tensor intended as input for a classification model.
- Returns
True if the input tensor requires normalization AND quantization (you should run ClassificationPreprocess()); false otherwise.
-
bool ClassificationPreprocess(TfLiteTensor *input_tensor)
Performs normalization and quantization pre-processing on the given tensor.
- Parameters
input_tensor – The tensor you want to pre-process for a classification model.
- Returns
True upon success; false if the tensor type is the wrong format.
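As a sketch of how these two functions fit together (assuming input_tensor has already been filled with image data, as in the example above):
// Normalize and quantize the input only when the model actually needs it.
if (tensorflow::ClassificationInputNeedsPreprocessing(*input_tensor)) {
  if (!tensorflow::ClassificationPreprocess(input_tensor)) {
    printf("ERROR: ClassificationPreprocess() failed\r\n");
    return;
  }
}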
Object detection
These APIs simplify the post-processing for object detection models.
Example (from examples/detect_objects_file/):
namespace coralmicro {
namespace {
constexpr char kModelPath[] =
"/models/tf2_ssd_mobilenet_v2_coco17_ptq_edgetpu.tflite";
constexpr char kImagePath[] = "/examples/detect_objects_file/cat_300x300.rgb";
constexpr int kTensorArenaSize = 8 * 1024 * 1024;
STATIC_TENSOR_ARENA_IN_SDRAM(tensor_arena, kTensorArenaSize);
void Main() {
printf("Detect Image Example!\r\n");
// Turn on Status LED to show the board is on.
LedSet(Led::kStatus, true);
std::vector<uint8_t> model;
if (!LfsReadFile(kModelPath, &model)) {
printf("ERROR: Failed to load %s\r\n", kModelPath);
return;
}
auto tpu_context = EdgeTpuManager::GetSingleton()->OpenDevice();
if (!tpu_context) {
printf("ERROR: Failed to get EdgeTpu context\r\n");
return;
}
tflite::MicroErrorReporter error_reporter;
tflite::MicroMutableOpResolver<3> resolver;
resolver.AddDequantize();
resolver.AddDetectionPostprocess();
resolver.AddCustom(kCustomOp, RegisterCustomOp());
tflite::MicroInterpreter interpreter(tflite::GetModel(model.data()), resolver,
tensor_arena, kTensorArenaSize,
&error_reporter);
if (interpreter.AllocateTensors() != kTfLiteOk) {
printf("ERROR: AllocateTensors() failed\r\n");
return;
}
if (interpreter.inputs().size() != 1) {
printf("ERROR: Model must have only one input tensor\r\n");
return;
}
auto* input_tensor = interpreter.input_tensor(0);
if (!LfsReadFile(kImagePath, tflite::GetTensorData<uint8_t>(input_tensor),
input_tensor->bytes)) {
printf("ERROR: Failed to load %s\r\n", kImagePath);
return;
}
if (interpreter.Invoke() != kTfLiteOk) {
printf("ERROR: Invoke() failed\r\n");
return;
}
auto results = tensorflow::GetDetectionResults(&interpreter, 0.6, 3);
printf("%s\r\n", tensorflow::FormatDetectionOutput(results).c_str());
}
} // namespace
} // namespace coralmicro
extern "C" void app_main(void* param) {
(void)param;
coralmicro::Main();
vTaskSuspend(nullptr);
}
-
namespace coralmicro
-
namespace tensorflow
-
template<typename T>
struct BBox
- #include <detection.h>
Represents the bounding box of a detected object.
-
struct Object
- #include <detection.h>
Represents a detected object.
Functions
-
std::string FormatDetectionOutput(const std::vector<Object> &objects)
Formats the detection outputs into a string.
- Parameters
objects – A vector with all the objects in an object detection output.
- Returns
A description of all detected objects.
-
std::vector<Object> GetDetectionResults(const float *bboxes, const float *ids, const float *scores, size_t count, float threshold = -std::numeric_limits<float>::infinity(), size_t top_k = std::numeric_limits<size_t>::max())
Converts detection output tensors into a vector of Objects.
- Parameters
bboxes – The output tensor for all detected bounding boxes in box-corner encoding, for example: (ymin1, xmin1, ymax1, xmax1, ymin2, xmin2, …).
ids – The output tensor for all label IDs.
scores – The output tensor for all scores.
count – The number of detected objects (all tensors defined above have valid data for this number of objects).
threshold – The score threshold for results. All returned results have a score greater-than-or-equal-to this value.
top_k – The maximum number of predictions to return.
- Returns
The top_k object predictions (id, score, BBox), ordered by score (first element has the highest score).
-
std::vector<Object> GetDetectionResults(tflite::MicroInterpreter *interpreter, float threshold = -std::numeric_limits<float>::infinity(), size_t top_k = std::numeric_limits<size_t>::max())
Gets results from a detection model as a vector of Objects.
- Parameters
interpreter – The already-invoked interpreter for your detection model.
threshold – The score threshold for results. All returned results have a score greater-than-or-equal-to this value.
top_k – The maximum number of predictions to return.
- Returns
The top_k object predictions (id, score, BBox), ordered by score (first element has the highest score).
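For instance, a small sketch (assuming an already-invoked detection interpreter; the id and score field names mirror the Class fields used in the classification example above, so verify them against detection.h):
auto objects = tensorflow::GetDetectionResults(&interpreter, /*threshold=*/0.5f,
                                               /*top_k=*/5);
for (const auto& object : objects) {
  // Each Object carries a label id, a score, and a bounding box (see BBox).
  printf("id=%d score=%f\r\n", object.id, object.score);
}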
Pose estimation
These APIs not only simplify the post-processing for pose estimation with PoseNet, but also optimize execution of the post-processing layers on the MCU with a custom op (because the post-processing ops are not compatible with the Edge TPU).
So when running PoseNet, in addition to specifying the kCustomOp for the Edge TPU, you should also register the kPosenetDecoderOp provided here.
Example (from examples/detect_poses/):
namespace coralmicro {
namespace {
constexpr int kModelArenaSize = 1 * 1024 * 1024;
constexpr int kExtraArenaSize = 1 * 1024 * 1024;
constexpr int kTensorArenaSize = kModelArenaSize + kExtraArenaSize;
STATIC_TENSOR_ARENA_IN_SDRAM(tensor_arena, kTensorArenaSize);
constexpr char kModelPath[] =
"/models/posenet_mobilenet_v1_075_324_324_16_quant_decoder_edgetpu.tflite";
constexpr char kTestInputPath[] = "/models/posenet_test_input_324.bin";
void Main() {
printf("Posenet Example!\r\n");
// Turn on Status LED to show the board is on.
LedSet(Led::kStatus, true);
tflite::MicroErrorReporter error_reporter;
TF_LITE_REPORT_ERROR(&error_reporter, "Posenet!");
// Turn on the TPU and get its context.
auto tpu_context =
EdgeTpuManager::GetSingleton()->OpenDevice(PerformanceMode::kMax);
if (!tpu_context) {
printf("ERROR: Failed to get EdgeTpu context\r\n");
vTaskSuspend(nullptr);
}
// Reads the model and checks version.
std::vector<uint8_t> posenet_tflite;
if (!LfsReadFile(kModelPath, &posenet_tflite)) {
TF_LITE_REPORT_ERROR(&error_reporter, "Failed to load model!");
vTaskSuspend(nullptr);
}
auto* model = tflite::GetModel(posenet_tflite.data());
if (model->version() != TFLITE_SCHEMA_VERSION) {
TF_LITE_REPORT_ERROR(&error_reporter,
"Model schema version is %d, supported is %d",
model->version(), TFLITE_SCHEMA_VERSION);
vTaskSuspend(nullptr);
}
// Creates a micro interpreter.
tflite::MicroMutableOpResolver<2> resolver;
resolver.AddCustom(kCustomOp, RegisterCustomOp());
resolver.AddCustom(kPosenetDecoderOp, RegisterPosenetDecoderOp());
auto interpreter = tflite::MicroInterpreter{
model, resolver, tensor_arena, kTensorArenaSize, &error_reporter};
if (interpreter.AllocateTensors() != kTfLiteOk) {
TF_LITE_REPORT_ERROR(&error_reporter, "AllocateTensors failed.");
vTaskSuspend(nullptr);
}
auto* posenet_input = interpreter.input(0);
// Runs posenet on a test image.
printf("Getting outputs for posenet test input\r\n");
std::vector<uint8_t> posenet_test_input_bin;
if (!LfsReadFile(kTestInputPath, &posenet_test_input_bin)) {
TF_LITE_REPORT_ERROR(&error_reporter, "Failed to load test input!");
vTaskSuspend(nullptr);
}
if (posenet_input->bytes != posenet_test_input_bin.size()) {
TF_LITE_REPORT_ERROR(&error_reporter,
"Input tensor length doesn't match canned input\r\n");
vTaskSuspend(nullptr);
}
memcpy(tflite::GetTensorData<uint8_t>(posenet_input),
posenet_test_input_bin.data(), posenet_test_input_bin.size());
if (interpreter.Invoke() != kTfLiteOk) {
TF_LITE_REPORT_ERROR(&error_reporter, "Invoke failed.");
vTaskSuspend(nullptr);
}
auto test_image_output =
tensorflow::GetPosenetOutput(&interpreter, /*threshold=*/0.5);
printf("%s\r\n", tensorflow::FormatPosenetOutput(test_image_output).c_str());
// Starts the camera for live poses.
CameraTask::GetSingleton()->SetPower(true);
CameraTask::GetSingleton()->Enable(CameraMode::kStreaming);
printf("Starting live posenet\r\n");
auto model_height = posenet_input->dims->data[1];
auto model_width = posenet_input->dims->data[2];
for (;;) {
CameraFrameFormat fmt{
/*fmt=*/CameraFormat::kRgb,
/*filter=*/CameraFilterMethod::kBilinear,
/*rotation=*/CameraRotation::k270,
/*width=*/model_width,
/*height=*/model_height,
/*preserve_ratio=*/false,
/*buffer=*/tflite::GetTensorData<uint8_t>(posenet_input)};
if (!CameraTask::GetSingleton()->GetFrame({fmt})) {
TF_LITE_REPORT_ERROR(&error_reporter, "Failed to get image from camera.");
break;
}
if (interpreter.Invoke() != kTfLiteOk) {
TF_LITE_REPORT_ERROR(&error_reporter, "Invoke failed.");
break;
}
auto output = tensorflow::GetPosenetOutput(&interpreter,
/*threshold=*/0.5);
printf("%s\r\n", tensorflow::FormatPosenetOutput(output).c_str());
vTaskDelay(pdMS_TO_TICKS(100));
}
CameraTask::GetSingleton()->SetPower(false);
}
} // namespace
} // namespace coralmicro
extern "C" void app_main(void* param) {
(void)param;
coralmicro::Main();
vTaskSuspend(nullptr);
}
-
namespace coralmicro
Functions
-
TfLiteRegistration *RegisterPosenetDecoderOp()
Returns a pointer to an instance of tflite::TfLiteRegistration to handle the custom op for post-processing PoseNet output tensors on the MCU. Pass this to tflite::MicroMutableOpResolver::AddCustom().
Variables
-
constexpr char kPosenetDecoderOp[] = "PosenetDecoderOp"
PoseNet custom op name. Pass this to tflite::MicroMutableOpResolver::AddCustom().
-
namespace coralmicro
-
namespace tensorflow
Functions
-
std::string FormatPosenetOutput(const std::vector<Pose> &poses)
Formats all the PoseNet output into a string.
- Parameters
poses – A vector containing all the poses in a PoseNet output.
- Returns
A string describing the PoseNet output.
-
std::vector<Pose> GetPosenetOutput(tflite::MicroInterpreter *interpreter, float threshold = -std::numeric_limits<float>::infinity())
Gets the results from a PoseNet model in the form of a vector of poses.
After you invoke the interpreter, pass it to this function to get structured pose results.
- Parameters
interpreter – The already-invoked interpreter for your PoseNet model.
threshold – The overall pose score threshold for results.
- Returns
All detected poses with an overall score greater-than-or-equal-to the threshold.
Variables
-
constexpr int kKeypoints = 17
Number of keypoints in each pose.
-
const char *const KeypointTypes[] = {"NOSE", "LEFT_EYE", "RIGHT_EYE", "LEFT_EAR", "RIGHT_EAR", "LEFT_SHOULDER", "RIGHT_SHOULDER", "LEFT_ELBOW", "RIGHT_ELBOW", "LEFT_WRIST", "RIGHT_WRIST", "LEFT_HIP", "RIGHT_HIP", "LEFT_KNEE", "RIGHT_KNEE", "LEFT_ANKLE", "RIGHT_ANKLE"}
A map of keypoint index to the keypoint name.
-
struct Keypoint
- #include <posenet.h>
The location and score of a pose keypoint.
-
struct Pose
- #include <posenet.h>
Represents an individual pose.
Audio Classification
The following APIs assist with running audio classification models on the Dev Board Micro, either on the CPU or the Edge TPU. For supported models, see the audio classification models.
For an example, see examples/classify_speech/.
-
namespace coralmicro
-
namespace tensorflow
Enums
Functions
-
template<bool tForTpu>
auto SetupYamNetResolver()
Sets up the MicroMutableOpResolver with the ops required for YamNet.
- Template Parameters
tForTpu – If true, the resolver is set up for the Edge TPU; otherwise it is set up for the CPU.
- Returns
A tflite::MicroMutableOpResolver that is prepared for the YamNet model.
-
bool PrepareAudioFrontEnd(FrontendState *frontend_state, AudioModel model_type)
Prepares the input preprocessing engine for TensorFlow to convert raw audio data to a spectrogram. This function must be called before PreprocessAudioInput() is called.
- Parameters
frontend_state – The FrontendState struct to populate.
model_type – The type of audio model.
- Returns
True on FrontendPopulateState() success, false otherwise.
-
void YamNetPreprocessInput(const int16_t *audio_data, TfLiteTensor *input_tensor, FrontendState *frontend_state)
Performs input preprocessing to convert raw audio input to a spectrogram.
- Parameters
audio_data – An array of signed int16 audio data.
input_tensor – The tensor where the preprocessed spectrogram data is stored.
frontend_state – The populated frontend state used to preprocess the input tensor; must not be nullptr.
-
void KeywordDetectorPreprocessInput(const int16_t *audio_data, TfLiteTensor *input_tensor, FrontendState *frontend_state)
Performs input preprocessing to convert raw audio input to a spectrogram.
- Parameters
audio_data – An array of signed int16 audio data.
input_tensor – The tensor you want to pre-process for a TensorFlow model; must not be nullptr.
frontend_state – The populated frontend state used to preprocess the input tensor; must not be nullptr.
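A rough sketch of how these pieces fit together for YamNet (the AudioModel enumerator name and the audio-sample source are assumptions here; check the audio model headers and examples/classify_speech/ in coralmicro for the exact names):
// Build a resolver preconfigured for YamNet on the Edge TPU.
auto resolver = tensorflow::SetupYamNetResolver</*tForTpu=*/true>();

// Prepare the audio frontend once, before any preprocessing calls.
FrontendState frontend_state;
if (!tensorflow::PrepareAudioFrontEnd(&frontend_state,
                                      tensorflow::AudioModel::kYAMNet)) {  // enumerator name assumed
  printf("ERROR: PrepareAudioFrontEnd() failed\r\n");
  return;
}

// Convert raw int16 samples into the spectrogram the model expects.
// audio_samples and input_tensor are assumed to exist already (e.g. from the
// microphone service and interpreter.input_tensor(0)).
tensorflow::YamNetPreprocessInput(audio_samples.data(), input_tensor,
                                  &frontend_state);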
Variables
- constexpr int kYamnetSampleRate = 16000
- constexpr int kYamnetSampleRateMs = kYamnetSampleRate / 1000
- constexpr int kYamnetDurationMs = 975
- constexpr int kYamnetAudioSize = kYamnetSampleRate * kYamnetDurationMs / 1000
- constexpr int kYamnetFeatureSliceSize = 64
- constexpr int kYamnetFeatureSliceCount = 96
- constexpr int kYamnetFeatureElementCount = (kYamnetFeatureSliceSize * kYamnetFeatureSliceCount)
- constexpr int kYamnetFeatureSliceStrideMs = 10
- constexpr int kYamnetFeatureSliceDurationMs = 25
- constexpr int kKeywordDetectorSampleRate = 16000
- constexpr int kKeywordDetectorSampleRateMs = kKeywordDetectorSampleRate / 1000
- constexpr int kKeywordDetectorDurationMs = 2000
- constexpr int kKeywordDetectorAudioSize = kKeywordDetectorSampleRate * kKeywordDetectorDurationMs / 1000
- constexpr int kKeywordDetectorFeatureSliceSize = 32
- constexpr int kKeywordDetectorFeatureSliceCount = 198
- constexpr int kKeywordDetectorFeatureElementCount = (kKeywordDetectorFeatureSliceSize * kKeywordDetectorFeatureSliceCount)
- constexpr int kKeywordDetectorFeatureSliceStrideMs = 10
- constexpr int kKeywordDetectorFeatureSliceDurationMs = 25
Utilities
The following functions help with some common tasks during inferencing, such as manipulating images and tensors.
-
namespace coralmicro
-
namespace tensorflow
-
struct ImageDims
- #include <utils.h>
Represents the dimensions of an image.
Functions
-
inline bool operator==(const ImageDims &a, const ImageDims &b)
Compares two ImageDims objects for equality.
-
bool ResizeImage(const ImageDims &in_dims, const uint8_t *uin, const ImageDims &out_dims, uint8_t *uout)
Resizes a bitmap image.
- Parameters
in_dims – The current dimensions for image uin.
uin – The input image location.
out_dims – The desired dimensions for image uout.
uout – The output image location.
-
inline int TensorSize(TfLiteTensor *tensor)
Gets the size of a tensor.
- Parameters
tensor – The tensor whose size you want.
- Returns
The size of the tensor.
-
template<typename I, typename O>
void Dequantize(int tensor_size, I *tensor_data, O *dequant_data, float scale, float zero_point)
Dequantizes data.
- Parameters
tensor_size – The tensor's size.
tensor_data – The tensor's data.
dequant_data – The buffer to return the dequantized data to.
scale – The scale of the input tensor.
zero_point – The zero point of the input tensor.
- Template Parameters
I – The data type of tensor_data.
O – The desired data type of the dequantized data. Note: You should instead use DequantizeTensor().
-
template<typename T>
std::vector<T> DequantizeTensor(TfLiteTensor *tensor)
Dequantizes a tensor.
- Parameters
tensor – The tensor to dequantize.
- Template Parameters
T – The desired output type of the dequantized data. When using a model adapter API such as GetClassificationResults(), this dequantization is done for you.
- Returns
A vector of dequantized data.
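For example, a sketch (assuming an already-invoked interpreter whose output tensor is quantized):
// Convert the quantized output tensor into float values.
auto* output_tensor = interpreter.output_tensor(0);
std::vector<float> values = tensorflow::DequantizeTensor<float>(output_tensor);
if (!values.empty()) {
  printf("First dequantized value: %f\r\n", values[0]);
}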
Defines
-
STATIC_TENSOR_ARENA_IN_SDRAM(name, size)
Allocates a uint8_t tensor arena statically in the Dev Board Micro SDRAM (max of 64 MB). This is slightly slower than OCRAM due to off-chip I/O overhead costs.
- Parameters
name – The variable name for this allocation.
size – The byte size to allocate. This macro automatically aligns the size to 16 bytes.
-
STATIC_TENSOR_ARENA_IN_OCRAM(name, size)
Allocates a uint8_t tensor arena statically in the RT1176 on-chip RAM (max of 1.25 MB).
- Parameters
name – The variable name for this allocation.
size – The byte size to allocate. This macro automatically aligns the size to 16 bytes.
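As a sketch of the trade-off (the sizes below are illustrative only): keep small arenas in on-chip RAM for faster access, and fall back to SDRAM when the model needs more space than OCRAM can hold.
// Small model: fits comfortably in the 1.25 MB of on-chip RAM (faster access).
constexpr int kSmallArenaSize = 512 * 1024;
STATIC_TENSOR_ARENA_IN_OCRAM(small_tensor_arena, kSmallArenaSize);

// Large model: needs more than OCRAM can offer, so place the arena in SDRAM.
constexpr int kLargeArenaSize = 8 * 1024 * 1024;
STATIC_TENSOR_ARENA_IN_SDRAM(large_tensor_arena, kLargeArenaSize);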