ESPHome 2026.6.0-dev
Loading...
Searching...
No Matches
streaming_model.cpp
Go to the documentation of this file.
1#include "streaming_model.h"
2
3#ifdef USE_ESP32
4
6#include "esphome/core/log.h"
7
8static const char *const TAG = "micro_wake_word";
9
11
13 ESP_LOGCONFIG(TAG,
14 " - Wake Word: %s\n"
15 " Probability cutoff: %.2f\n"
16 " Sliding window size: %d",
17 this->wake_word_.c_str(), this->probability_cutoff_ / 255.0f, this->sliding_window_size_);
18}
19
21 ESP_LOGCONFIG(TAG,
22 " - VAD Model\n"
23 " Probability cutoff: %.2f\n"
24 " Sliding window size: %d",
25 this->probability_cutoff_ / 255.0f, this->sliding_window_size_);
26}
27
29 RAMAllocator<uint8_t> arena_allocator;
30
31 if (this->var_arena_ == nullptr) {
32 this->var_arena_ = arena_allocator.allocate(STREAMING_MODEL_VARIABLE_ARENA_SIZE);
33 if (this->var_arena_ == nullptr) {
34 ESP_LOGE(TAG, "Could not allocate the streaming model's variable tensor arena.");
35 return false;
36 }
37 this->ma_ = tflite::MicroAllocator::Create(this->var_arena_, STREAMING_MODEL_VARIABLE_ARENA_SIZE);
38 this->mrv_ = tflite::MicroResourceVariables::Create(this->ma_, 20);
39 }
40
41 const tflite::Model *model = tflite::GetModel(this->model_start_);
42 if (model->version() != TFLITE_SCHEMA_VERSION) {
43 ESP_LOGE(TAG, "Streaming model's schema is not supported");
44 return false;
45 }
46
47 // Probe for the actual required tensor arena size if not yet determined
48 if (!this->tensor_arena_size_probed_) {
49 size_t probed_size = this->probe_arena_size_();
50 if (probed_size > 0) {
51 ESP_LOGD(TAG, "Probed tensor arena size: %zu bytes", probed_size);
52 this->tensor_arena_size_ = probed_size;
53 } else {
54 ESP_LOGW(TAG, "Arena size probe failed, using manifest size: %zu bytes", this->tensor_arena_size_);
55 }
56 this->tensor_arena_size_probed_ = true;
57 }
58
59 if (this->tensor_arena_ == nullptr) {
60 this->tensor_arena_ = arena_allocator.allocate(this->tensor_arena_size_);
61 if (this->tensor_arena_ == nullptr) {
62 ESP_LOGE(TAG, "Could not allocate the streaming model's tensor arena.");
63 return false;
64 }
65 }
66
67 if (this->interpreter_ == nullptr) {
68 this->interpreter_ =
69 make_unique<tflite::MicroInterpreter>(tflite::GetModel(this->model_start_), this->streaming_op_resolver_,
70 this->tensor_arena_, this->tensor_arena_size_, this->mrv_);
71 if (this->interpreter_->AllocateTensors() != kTfLiteOk) {
72 ESP_LOGE(TAG, "Failed to allocate tensors for the streaming model");
73 return false;
74 }
75
76 // Verify input tensor matches expected values
77 // Dimension index 1 holds the first layer's stride, which may vary, so it is not checked
78 TfLiteTensor *input = this->interpreter_->input(0);
79 if ((input->dims->size != 3) || (input->dims->data[0] != 1) ||
80 (input->dims->data[2] != PREPROCESSOR_FEATURE_SIZE)) {
81 ESP_LOGE(TAG, "Streaming model tensor input dimensions has improper dimensions.");
82 return false;
83 }
84
85 if (input->type != kTfLiteInt8) {
86 ESP_LOGE(TAG, "Streaming model tensor input is not int8.");
87 return false;
88 }
89
90 // Verify output tensor matches expected values
91 TfLiteTensor *output = this->interpreter_->output(0);
92 if ((output->dims->size != 2) || (output->dims->data[0] != 1) || (output->dims->data[1] != 1)) {
93 ESP_LOGE(TAG, "Streaming model tensor output dimension is not 1x1.");
94 return false;
95 }
96
97 if (output->type != kTfLiteUInt8) {
98 ESP_LOGE(TAG, "Streaming model tensor output is not uint8.");
99 return false;
100 }
101 }
102
103 this->loaded_ = true;
104 this->reset_probabilities();
105 return true;
106}
107
109 RAMAllocator<uint8_t> arena_allocator;
110
111 // Try the manifest size first, then escalate to 1.5x, then 2x if it fails. Different platforms and different
112 // versions of the esp-nn library require different amounts of memory, so the manifest size may not always be correct,
113 // and probing allows us to find the actual required size for the current build and platform. Aligns test sizes to 16
114 // bytes.
115 size_t attempt_sizes[] = {(this->tensor_arena_size_ + 15) & ~15, (this->tensor_arena_size_ * 3 / 2 + 15) & ~15,
116 (this->tensor_arena_size_ * 2 + 15) & ~15};
117
118 for (size_t attempt_size : attempt_sizes) {
119 uint8_t *probe_arena = arena_allocator.allocate(attempt_size);
120 if (probe_arena == nullptr) {
121 continue;
122 }
123
124 // Verify the model works at all with this arena size
125 auto probe_interpreter = make_unique<tflite::MicroInterpreter>(
126 tflite::GetModel(this->model_start_), this->streaming_op_resolver_, probe_arena, attempt_size, this->mrv_);
127
128 if (probe_interpreter->AllocateTensors() != kTfLiteOk) {
129 probe_interpreter.reset();
130 arena_allocator.deallocate(probe_arena, attempt_size);
131 this->ma_ = tflite::MicroAllocator::Create(this->var_arena_, STREAMING_MODEL_VARIABLE_ARENA_SIZE);
132 this->mrv_ = tflite::MicroResourceVariables::Create(this->ma_, 20);
133 continue;
134 }
135
136 // Try to shrink the arena. Start with arena_used_bytes() + 16 (rounded to 16-byte alignment).
137 // If that works, use it. Otherwise, try midpoints between that and the full size until one succeeds.
138 size_t lower = (probe_interpreter->arena_used_bytes() + 16 + 15) & ~15;
139 probe_interpreter.reset();
140 this->ma_ = tflite::MicroAllocator::Create(this->var_arena_, STREAMING_MODEL_VARIABLE_ARENA_SIZE);
141 this->mrv_ = tflite::MicroResourceVariables::Create(this->ma_, 20);
142
143 size_t upper = attempt_size;
144
145 while (lower < upper) {
146 auto test_interpreter = make_unique<tflite::MicroInterpreter>(
147 tflite::GetModel(this->model_start_), this->streaming_op_resolver_, probe_arena, lower, this->mrv_);
148
149 bool ok = test_interpreter->AllocateTensors() == kTfLiteOk;
150
151 test_interpreter.reset();
152 this->ma_ = tflite::MicroAllocator::Create(this->var_arena_, STREAMING_MODEL_VARIABLE_ARENA_SIZE);
153 this->mrv_ = tflite::MicroResourceVariables::Create(this->ma_, 20);
154
155 if (ok) {
156 // Found a working size smaller than the full arena
157 upper = lower + 16; // Pad by 16 bytes to be safe for future allocations
158 break;
159 }
160
161 // Try the midpoint between current attempt and full size
162 lower = ((lower + upper) / 2 + 15) & ~15;
163 }
164
165 arena_allocator.deallocate(probe_arena, attempt_size);
166 return upper;
167 }
168
169 return 0;
170}
171
173 this->interpreter_.reset();
174
175 RAMAllocator<uint8_t> arena_allocator;
176
177 if (this->tensor_arena_ != nullptr) {
178 arena_allocator.deallocate(this->tensor_arena_, this->tensor_arena_size_);
179 this->tensor_arena_ = nullptr;
180 }
181
182 if (this->var_arena_ != nullptr) {
183 arena_allocator.deallocate(this->var_arena_, STREAMING_MODEL_VARIABLE_ARENA_SIZE);
184 this->var_arena_ = nullptr;
185 }
186
187 this->loaded_ = false;
188}
189
// Feeds one slice of preprocessor features to the streaming model and, once enough slices have been
// gathered, runs an inference and records the resulting probability.
//
// The model is loaded or unloaded here on demand, depending on the enabled_ and loaded_ flags.
//
// @param features PREPROCESSOR_FEATURE_SIZE int8 values; they are copied into the model's input tensor.
// @return false if load_model_() fails or the interpreter's Invoke() does not return kTfLiteOk; true otherwise
//         (including the case where a disabled model was just unloaded and no inference was run).
190bool StreamingModel::perform_streaming_inference(const int8_t features[PREPROCESSOR_FEATURE_SIZE]) {
191 if (this->enabled_ && !this->loaded_) {
192 // Model is enabled but isn't loaded
193 if (!this->load_model_()) {
194 return false;
195 }
196 }
197
198 if (!this->enabled_ && this->loaded_) {
199 // Model is disabled but still loaded
200 this->unload_model();
201 return true;
202 }
203
204 if (this->loaded_) {
205 TfLiteTensor *input = this->interpreter_->input(0);
206
    // The second input dimension is used as the number of feature slices collected before each Invoke().
    // NOTE(review): load_model_ validates dims->data[0] and dims->data[2] only; a value of 0 here would make the
    // modulo below undefined - confirm the models in use can never report a zero stride.
207 uint8_t stride = this->interpreter_->input(0)->dims->data[1];
    // Wraps the step counter back into [0, stride); this is also what resets it after a completed inference,
    // because the counter is not cleared after Invoke() below.
208 this->current_stride_step_ = this->current_stride_step_ % stride;
209
    // Copy this slice into its slot of the input tensor: slot offset = PREPROCESSOR_FEATURE_SIZE * step.
210 std::memmove(
211 (int8_t *) (tflite::GetTensorData<int8_t>(input)) + PREPROCESSOR_FEATURE_SIZE * this->current_stride_step_,
212 features, PREPROCESSOR_FEATURE_SIZE);
213 ++this->current_stride_step_;
214
    // Only run the model once every slot of the input tensor has been filled.
215 if (this->current_stride_step_ >= stride) {
216 TfLiteStatus invoke_status = this->interpreter_->Invoke();
217 if (invoke_status != kTfLiteOk) {
218 ESP_LOGW(TAG, "Streaming interpreter invoke failed");
219 return false;
220 }
221
222 TfLiteTensor *output = this->interpreter_->output(0);
223
    // recent_streaming_probabilities_ is written as a circular buffer of sliding_window_size_ entries.
224 ++this->last_n_index_;
225 if (this->last_n_index_ == this->sliding_window_size_)
226 this->last_n_index_ = 0;
227 this->recent_streaming_probabilities_[this->last_n_index_] = output->data.uint8[0]; // probability;
229 }
    // NOTE(review): the comment below describes a "less than the probability cutoff" condition, but no such
    // guard is visible in this listing (the embedded line numbering skips entries here) - verify against the
    // real source file before relying on this description.
231 // Only increment ignore windows if less than the probability cutoff; this forces the model to "cool-off" from a
232 // previous detection and calling ``reset_probabilities`` so it avoids duplicate detections
    // ignore_windows_ is raised by one per call but capped at 0.
233 this->ignore_windows_ = std::min(this->ignore_windows_ + 1, 0);
234 }
235 }
236 return true;
237}
238
240 for (auto &prob : this->recent_streaming_probabilities_) {
241 prob = 0;
242 }
243 this->ignore_windows_ = -MIN_SLICES_BEFORE_DETECTION;
244}
245
// Constructs a wake word model object.
//
// @param id                          Identifier stored in id_.
// @param model_start                 Pointer to the start of the model data; stored, not copied.
// @param default_probability_cutoff  Initial value for both default_probability_cutoff_ and probability_cutoff_.
// @param sliding_window_average_size Number of recent probabilities kept; also the size of the probability buffer.
// @param wake_word                   Wake word text stored in wake_word_.
// @param tensor_arena_size           Initial tensor arena size in bytes.
// @param default_enabled             Enabled state used when no saved state can be loaded from pref_.
// @param internal_only               Stored in internal_only_; enable()/disable() skip saving to pref_ when set.
246WakeWordModel::WakeWordModel(const std::string &id, const uint8_t *model_start, uint8_t default_probability_cutoff,
247 size_t sliding_window_average_size, const std::string &wake_word, size_t tensor_arena_size,
248 bool default_enabled, bool internal_only) {
249 this->id_ = id;
250 this->model_start_ = model_start;
251 this->default_probability_cutoff_ = default_probability_cutoff;
252 this->probability_cutoff_ = default_probability_cutoff;
253 this->sliding_window_size_ = sliding_window_average_size;
    // Size the probability buffer to the window length, with every entry starting at 0.
254 this->recent_streaming_probabilities_.resize(sliding_window_average_size, 0);
255 this->wake_word_ = wake_word;
256 this->tensor_arena_size_ = tensor_arena_size;
258 this->current_stride_step_ = 0;
259 this->internal_only_ = internal_only;
260
    // NOTE(review): pref_ is not assigned in any line visible here (the embedded numbering skips entries just
    // above) - confirm it is initialised before load() is called.
262 bool enabled;
263 if (this->pref_.load(&enabled)) {
264 // Use the enabled state loaded from flash
265 this->enabled_ = enabled;
266 } else {
267 // If no state saved, then use the default
268 this->enabled_ = default_enabled;
269 }
270};
271
273 this->enabled_ = true;
274 if (!this->internal_only_) {
275 this->pref_.save(&this->enabled_);
276 }
277}
278
280 this->enabled_ = false;
281 if (!this->internal_only_) {
282 this->pref_.save(&this->enabled_);
283 }
284}
285
287 DetectionEvent detection_event;
288 detection_event.wake_word = &this->wake_word_;
289 detection_event.max_probability = 0;
290 detection_event.average_probability = 0;
291
292 if ((this->ignore_windows_ < 0) || !this->enabled_) {
293 detection_event.detected = false;
294 return detection_event;
295 }
296
297 uint32_t sum = 0;
298 for (auto &prob : this->recent_streaming_probabilities_) {
299 detection_event.max_probability = std::max(detection_event.max_probability, prob);
300 sum += prob;
301 }
302
303 detection_event.average_probability = sum / this->sliding_window_size_;
304 detection_event.detected = sum > this->probability_cutoff_ * this->sliding_window_size_;
305
307 return detection_event;
308}
309
// Constructs a voice activity detection model object.
//
// @param model_start                Pointer to the start of the model data; stored, not copied.
// @param default_probability_cutoff Initial value for both default_probability_cutoff_ and probability_cutoff_.
// @param sliding_window_size        Number of recent probabilities kept; also the size of the probability buffer.
// @param tensor_arena_size          Initial tensor arena size in bytes.
310VADModel::VADModel(const uint8_t *model_start, uint8_t default_probability_cutoff, size_t sliding_window_size,
311 size_t tensor_arena_size) {
312 this->model_start_ = model_start;
313 this->default_probability_cutoff_ = default_probability_cutoff;
314 this->probability_cutoff_ = default_probability_cutoff;
315 this->sliding_window_size_ = sliding_window_size;
    // Size the probability buffer to the window length, with every entry starting at 0.
316 this->recent_streaming_probabilities_.resize(sliding_window_size, 0);
317 this->tensor_arena_size_ = tensor_arena_size;
    // NOTE(review): the embedded line numbering skips an entry before the closing brace - check the real source
    // for a statement that is not shown in this listing.
319}
320
322 DetectionEvent detection_event;
323 detection_event.max_probability = 0;
324 detection_event.average_probability = 0;
325
326 if (!this->enabled_) {
327 // We disabled the VAD model for some reason... so we shouldn't block wake words from being detected
328 detection_event.detected = true;
329 return detection_event;
330 }
331
332 uint32_t sum = 0;
333 for (auto &prob : this->recent_streaming_probabilities_) {
334 detection_event.max_probability = std::max(detection_event.max_probability, prob);
335 sum += prob;
336 }
337
338 detection_event.average_probability = sum / this->sliding_window_size_;
339 detection_event.detected = sum > (this->probability_cutoff_ * this->sliding_window_size_);
340
341 return detection_event;
342}
343
344bool StreamingModel::register_streaming_ops_(tflite::MicroMutableOpResolver<20> &op_resolver) {
345 if (op_resolver.AddCallOnce() != kTfLiteOk)
346 return false;
347 if (op_resolver.AddVarHandle() != kTfLiteOk)
348 return false;
349 if (op_resolver.AddReshape() != kTfLiteOk)
350 return false;
351 if (op_resolver.AddReadVariable() != kTfLiteOk)
352 return false;
353 if (op_resolver.AddStridedSlice() != kTfLiteOk)
354 return false;
355 if (op_resolver.AddConcatenation() != kTfLiteOk)
356 return false;
357 if (op_resolver.AddAssignVariable() != kTfLiteOk)
358 return false;
359 if (op_resolver.AddConv2D() != kTfLiteOk)
360 return false;
361 if (op_resolver.AddMul() != kTfLiteOk)
362 return false;
363 if (op_resolver.AddAdd() != kTfLiteOk)
364 return false;
365 if (op_resolver.AddMean() != kTfLiteOk)
366 return false;
367 if (op_resolver.AddFullyConnected() != kTfLiteOk)
368 return false;
369 if (op_resolver.AddLogistic() != kTfLiteOk)
370 return false;
371 if (op_resolver.AddQuantize() != kTfLiteOk)
372 return false;
373 if (op_resolver.AddDepthwiseConv2D() != kTfLiteOk)
374 return false;
375 if (op_resolver.AddAveragePool2D() != kTfLiteOk)
376 return false;
377 if (op_resolver.AddMaxPool2D() != kTfLiteOk)
378 return false;
379 if (op_resolver.AddPad() != kTfLiteOk)
380 return false;
381 if (op_resolver.AddPack() != kTfLiteOk)
382 return false;
383 if (op_resolver.AddSplitV() != kTfLiteOk)
384 return false;
385
386 return true;
387}
388
389} // namespace esphome::micro_wake_word
390
391#endif
An STL allocator that uses SPI or internal RAM.
Definition helpers.h:2053
void deallocate(T *p, size_t n)
Definition helpers.h:2110
T * allocate(size_t n)
Definition helpers.h:2080
bool load_model_()
Allocates tensor and variable arenas and sets up the model interpreter.
std::unique_ptr< tflite::MicroInterpreter > interpreter_
tflite::MicroMutableOpResolver< 20 > streaming_op_resolver_
bool register_streaming_ops_(tflite::MicroMutableOpResolver< 20 > &op_resolver)
Returns true if successfully registered the streaming model's TensorFlow operations.
void reset_probabilities()
Sets all recent_streaming_probabilities to 0 and resets the ignore window count.
std::vector< uint8_t > recent_streaming_probabilities_
size_t probe_arena_size_()
Probes the actual required tensor arena size by trial allocation.
tflite::MicroResourceVariables * mrv_
bool perform_streaming_inference(const int8_t features[PREPROCESSOR_FEATURE_SIZE])
void unload_model()
Destroys the TFLite interpreter and frees the tensor and variable arenas' memory.
DetectionEvent determine_detected() override
Checks for voice activity by comparing the max probability in the sliding window with the probability...
VADModel(const uint8_t *model_start, uint8_t default_probability_cutoff, size_t sliding_window_size, size_t tensor_arena_size)
void enable() override
Enable the model and save to flash. The next performing_streaming_inference call will load it.
DetectionEvent determine_detected() override
Checks for the wake word by comparing the mean probability in the sliding window with the probability...
WakeWordModel(const std::string &id, const uint8_t *model_start, uint8_t default_probability_cutoff, size_t sliding_window_average_size, const std::string &wake_word, size_t tensor_arena_size, bool default_enabled, bool internal_only)
Constructs a wake word model object.
void disable() override
Disable the model and save to flash. The next performing_streaming_inference call will unload it.
uint16_t id
ESPPreferences * global_preferences
uint32_t fnv1_hash(const char *str)
Calculate a FNV-1 hash of str.
Definition helpers.cpp:160
static void uint32_t
ESPPreferenceObject make_preference(size_t, uint32_t, bool)
Definition preferences.h:24