ESPHome 2026.6.0-dev
Loading...
Searching...
No Matches
micro_wake_word.cpp
Go to the documentation of this file.
1#include "micro_wake_word.h"
2
3#ifdef USE_ESP32
4
6#include "esphome/core/hal.h"
8#include "esphome/core/log.h"
9
11
12#ifdef USE_OTA
14#endif
15
17
18static const char *const TAG = "micro_wake_word";
19
20static const ssize_t DETECTION_QUEUE_LENGTH = 5;
21
22static const size_t DATA_TIMEOUT_MS = 50;
23
24static const uint32_t RING_BUFFER_DURATION_MS = 120;
25
26#ifdef CONFIG_IDF_TARGET_ESP32P4
27// ESP32-P4 PIE-optimized esp-nn kernels (e.g. depthwise_conv_s8_ch1_pie) require
28// significantly more stack than other variants, causing stack protection faults at 3072.
29static const uint32_t INFERENCE_TASK_STACK_SIZE = 8192;
30#else
31static const uint32_t INFERENCE_TASK_STACK_SIZE = 3072;
32#endif
33static const UBaseType_t INFERENCE_TASK_PRIORITY = 3;
34
36 COMMAND_STOP = (1 << 0), // Signals the inference task should stop
37 COMMAND_RESET_RING_BUFFER = (1 << 1), // Signals the inference task to discard buffered audio
38
39 TASK_STARTING = (1 << 3), // Set by the inference task on entry, before its buffers are allocated
40 TASK_RUNNING = (1 << 4), // Set by the inference task once the microphone source has been started
41 TASK_STOPPING = (1 << 5), // Set by the inference task after its processing loop has exited
42 TASK_STOPPED = (1 << 6), // Set by the inference task right before it suspends itself
43
44 ERROR_MEMORY = (1 << 9), // Allocating the ring buffer or the audio source failed
45 ERROR_INFERENCE = (1 << 10), // Updating the model probabilities failed
46
48
50 ALL_BITS = 0xfffff, // Mask of the low 20 bits, used to clear every flag at once (an event group has 24 bits in total)
51};
52
54
55static const LogString *micro_wake_word_state_to_string(State state) {
56 switch (state) {
57 case State::STARTING:
58 return LOG_STR("STARTING");
60 return LOG_STR("DETECTING_WAKE_WORD");
61 case State::STOPPING:
62 return LOG_STR("STOPPING");
63 case State::STOPPED:
64 return LOG_STR("STOPPED");
65 default:
66 return LOG_STR("UNKNOWN");
67 }
68}
69
71 ESP_LOGCONFIG(TAG, "microWakeWord:");
72 ESP_LOGCONFIG(TAG, " models:");
73 for (auto &model : this->wake_word_models_) {
74 model->log_model_config();
75 }
76#ifdef USE_MICRO_WAKE_WORD_VAD
77 this->vad_model_->log_model_config();
78#endif
79}
80
82 this->frontend_config_.window.size_ms = FEATURE_DURATION_MS;
83 this->frontend_config_.window.step_size_ms = this->features_step_size_;
84 this->frontend_config_.filterbank.num_channels = PREPROCESSOR_FEATURE_SIZE;
85 this->frontend_config_.filterbank.lower_band_limit = FILTERBANK_LOWER_BAND_LIMIT;
86 this->frontend_config_.filterbank.upper_band_limit = FILTERBANK_UPPER_BAND_LIMIT;
87 this->frontend_config_.noise_reduction.smoothing_bits = NOISE_REDUCTION_SMOOTHING_BITS;
88 this->frontend_config_.noise_reduction.even_smoothing = NOISE_REDUCTION_EVEN_SMOOTHING;
89 this->frontend_config_.noise_reduction.odd_smoothing = NOISE_REDUCTION_ODD_SMOOTHING;
90 this->frontend_config_.noise_reduction.min_signal_remaining = NOISE_REDUCTION_MIN_SIGNAL_REMAINING;
91 this->frontend_config_.pcan_gain_control.enable_pcan = PCAN_GAIN_CONTROL_ENABLE_PCAN;
92 this->frontend_config_.pcan_gain_control.strength = PCAN_GAIN_CONTROL_STRENGTH;
93 this->frontend_config_.pcan_gain_control.offset = PCAN_GAIN_CONTROL_OFFSET;
94 this->frontend_config_.pcan_gain_control.gain_bits = PCAN_GAIN_CONTROL_GAIN_BITS;
95 this->frontend_config_.log_scale.enable_log = LOG_SCALE_ENABLE_LOG;
96 this->frontend_config_.log_scale.scale_shift = LOG_SCALE_SCALE_SHIFT;
97
98 this->event_group_ = xEventGroupCreate();
99 if (this->event_group_ == nullptr) {
100 ESP_LOGE(TAG, "Failed to create event group");
101 this->mark_failed();
102 return;
103 }
104
105 this->detection_queue_ = xQueueCreate(DETECTION_QUEUE_LENGTH, sizeof(DetectionEvent));
106 if (this->detection_queue_ == nullptr) {
107 ESP_LOGE(TAG, "Failed to create detection event queue");
108 this->mark_failed();
109 return;
110 }
111
112 this->microphone_source_->add_data_callback([this](const std::vector<uint8_t> &data) {
113 if (this->state_ == State::STOPPED) {
114 return;
115 }
116 std::shared_ptr<ring_buffer::RingBuffer> temp_ring_buffer = this->ring_buffer_.lock();
117 if (this->ring_buffer_.use_count() > 1) {
118 // Producer-only write: never touches consumer state. If the buffer is full, ask the inference task
119 // to drain it - reset() is a consumer operation and must run on the inference task's thread.
120 // Disable partial writes so audio chunks are either fully accepted or rejected and handled below.
121 if (temp_ring_buffer->write_without_replacement(data.data(), data.size(), 0, false) == 0) {
122 xEventGroupSetBits(this->event_group_,
124 }
125 }
126 });
127
128#ifdef USE_OTA_STATE_LISTENER
130#endif
131}
132
133#ifdef USE_OTA_STATE_LISTENER
134void MicroWakeWord::on_ota_global_state(ota::OTAState state, float progress, uint8_t error, ota::OTAComponent *comp) {
135 if (state == ota::OTA_STARTED) {
136 this->suspend_task_();
137 } else if (state == ota::OTA_ERROR) {
138 this->resume_task_();
139 }
140}
141#endif
142
144 MicroWakeWord *this_mww = (MicroWakeWord *) params;
145
146 xEventGroupSetBits(this_mww->event_group_, EventGroupBits::TASK_STARTING);
147
148 { // Ensures any C++ objects fall out of scope to deallocate before deleting the task
149
150 const auto &stream_info = this_mww->microphone_source_->get_audio_stream_info();
151 const size_t bytes_per_frame = stream_info.frames_to_bytes(1);
152 const size_t max_fill_bytes = stream_info.ms_to_bytes(this_mww->features_step_size_);
153 std::unique_ptr<audio::RingBufferAudioSource> audio_source;
154 int8_t features_buffer[PREPROCESSOR_FEATURE_SIZE];
155
156 if (!(xEventGroupGetBits(this_mww->event_group_) & ERROR_BITS)) {
157 // Round ring buffer size down to a frame multiple so the wrap boundary never splits an int16 sample.
158 const size_t ring_buffer_size =
159 (stream_info.ms_to_bytes(RING_BUFFER_DURATION_MS) / bytes_per_frame) * bytes_per_frame;
160 std::shared_ptr<ring_buffer::RingBuffer> temp_ring_buffer = ring_buffer::RingBuffer::create(ring_buffer_size);
161 if (temp_ring_buffer == nullptr) {
162 xEventGroupSetBits(this_mww->event_group_, EventGroupBits::ERROR_MEMORY);
163 } else {
164 audio_source = audio::RingBufferAudioSource::create(temp_ring_buffer, max_fill_bytes,
165 static_cast<uint8_t>(bytes_per_frame));
166 if (audio_source == nullptr) {
167 xEventGroupSetBits(this_mww->event_group_, EventGroupBits::ERROR_MEMORY);
168 } else {
169 this_mww->ring_buffer_ = temp_ring_buffer;
170 }
171 }
172 }
173
174 if (!(xEventGroupGetBits(this_mww->event_group_) & ERROR_BITS)) {
175 this_mww->microphone_source_->start();
176 xEventGroupSetBits(this_mww->event_group_, EventGroupBits::TASK_RUNNING);
177
178 while (!(xEventGroupGetBits(this_mww->event_group_) & (COMMAND_STOP | ERROR_BITS))) {
179 if (xEventGroupGetBits(this_mww->event_group_) & EventGroupBits::COMMAND_RESET_RING_BUFFER) {
180 // Producer asked us to drain; run the consumer-side reset from this thread.
181 audio_source->clear_buffered_data();
182 xEventGroupClearBits(this_mww->event_group_, EventGroupBits::COMMAND_RESET_RING_BUFFER);
183 }
184
185 audio_source->fill(pdMS_TO_TICKS(DATA_TIMEOUT_MS), false);
186
187 // The frontend buffers samples internally and only emits a feature once it has a full window, so we can
188 // hand it whatever the source exposes. The frontend consumes at least one sample per call, so available()
189 // strictly decreases and this loop always terminates.
190 while (audio_source->available() >= sizeof(int16_t)) {
191 const size_t samples_available = audio_source->available() / sizeof(int16_t);
192 const int16_t *audio_data = reinterpret_cast<const int16_t *>(audio_source->data());
193
194 size_t processed_samples = 0;
195 const bool feature_generated =
196 this_mww->generate_features_(audio_data, samples_available, features_buffer, &processed_samples);
197 audio_source->consume(processed_samples * sizeof(int16_t));
198
199 if (feature_generated) {
200 if (!this_mww->update_model_probabilities_(features_buffer)) {
201 xEventGroupSetBits(this_mww->event_group_, EventGroupBits::ERROR_INFERENCE);
202 break;
203 }
204
205 // Process each model's probabilities and possibly send a Detection Event to the queue
206 this_mww->process_probabilities_();
207 }
208 }
209 }
210 }
211 }
212
213 xEventGroupSetBits(this_mww->event_group_, EventGroupBits::TASK_STOPPING);
214
215 this_mww->unload_models_();
216 this_mww->microphone_source_->stop();
217 FrontendFreeStateContents(&this_mww->frontend_state_);
218
219 xEventGroupSetBits(this_mww->event_group_, EventGroupBits::TASK_STOPPED);
220 vTaskSuspend(nullptr); // Suspend this task indefinitely until the loop method deletes it
221}
222
223std::vector<WakeWordModel *> MicroWakeWord::get_wake_words() {
224 std::vector<WakeWordModel *> external_wake_word_models;
225 for (auto *model : this->wake_word_models_) {
226 if (!model->get_internal_only()) {
227 external_wake_word_models.push_back(model);
228 }
229 }
230 return external_wake_word_models;
231}
232
234
235#ifdef USE_MICRO_WAKE_WORD_VAD
236void MicroWakeWord::add_vad_model(const uint8_t *model_start, uint8_t probability_cutoff, size_t sliding_window_size,
237 size_t tensor_arena_size) {
238 this->vad_model_ = make_unique<VADModel>(model_start, probability_cutoff, sliding_window_size, tensor_arena_size);
239}
240#endif
241
243 if (this->inference_task_.is_created()) {
244 vTaskSuspend(this->inference_task_.get_handle());
245 }
246}
247
249 if (this->inference_task_.is_created()) {
250 vTaskResume(this->inference_task_.get_handle());
251 }
252}
253
255 uint32_t event_group_bits = xEventGroupGetBits(this->event_group_);
256
257 if (event_group_bits & EventGroupBits::ERROR_MEMORY) {
258 xEventGroupClearBits(this->event_group_, EventGroupBits::ERROR_MEMORY);
259 ESP_LOGE(TAG, "Encountered an error allocating buffers");
260 }
261
262 if (event_group_bits & EventGroupBits::ERROR_INFERENCE) {
263 xEventGroupClearBits(this->event_group_, EventGroupBits::ERROR_INFERENCE);
264 ESP_LOGE(TAG, "Encountered an error while performing an inference");
265 }
266
267 if (event_group_bits & EventGroupBits::WARNING_FULL_RING_BUFFER) {
268 xEventGroupClearBits(this->event_group_, EventGroupBits::WARNING_FULL_RING_BUFFER);
269 ESP_LOGW(TAG, "Not enough free bytes in ring buffer to store incoming audio data. Resetting the ring buffer. Wake "
270 "word detection accuracy will temporarily be reduced.");
271 }
272
273 if (event_group_bits & EventGroupBits::TASK_STARTING) {
274 ESP_LOGD(TAG, "Inference task has started, attempting to allocate memory for buffers");
275 xEventGroupClearBits(this->event_group_, EventGroupBits::TASK_STARTING);
276 }
277
278 if (event_group_bits & EventGroupBits::TASK_RUNNING) {
279 ESP_LOGD(TAG, "Inference task is running");
280
281 xEventGroupClearBits(this->event_group_, EventGroupBits::TASK_RUNNING);
283 }
284
285 if (event_group_bits & EventGroupBits::TASK_STOPPING) {
286 ESP_LOGD(TAG, "Inference task is stopping, deallocating buffers");
287 xEventGroupClearBits(this->event_group_, EventGroupBits::TASK_STOPPING);
288 }
289
290 if ((event_group_bits & EventGroupBits::TASK_STOPPED)) {
291 ESP_LOGD(TAG, "Inference task is finished, freeing task resources");
293 xEventGroupClearBits(this->event_group_, ALL_BITS);
294 xQueueReset(this->detection_queue_);
296 }
297
298 if ((this->pending_start_) && (this->state_ == State::STOPPED)) {
300 this->pending_start_ = false;
301 }
302
303 if ((this->pending_stop_) && (this->state_ == State::DETECTING_WAKE_WORD)) {
305 this->pending_stop_ = false;
306 }
307
308 switch (this->state_) {
309 case State::STARTING:
310 if (!this->inference_task_.is_created() && !this->status_has_error()) {
311 // Setup preprocessor feature generator. If done in the task, it would lock the task to its initial core, as it
312 // uses floating point operations.
313 if (!FrontendPopulateState(&this->frontend_config_, &this->frontend_state_,
315 this->status_momentary_error("frontend_alloc", 1000);
316 return;
317 }
318
319 if (!this->inference_task_.create(MicroWakeWord::inference_task, "mww", INFERENCE_TASK_STACK_SIZE,
320 (void *) this, INFERENCE_TASK_PRIORITY, this->task_stack_in_psram_)) {
321 FrontendFreeStateContents(&this->frontend_state_); // Deallocate frontend state
322 this->status_momentary_error("task_start", 1000);
323 }
324 }
325 break;
327 DetectionEvent detection_event;
328 while (xQueueReceive(this->detection_queue_, &detection_event, 0)) {
329 if (detection_event.blocked_by_vad) {
330 ESP_LOGD(TAG, "Wake word model predicts '%s', but VAD model doesn't.", detection_event.wake_word->c_str());
331 } else {
332 constexpr float uint8_to_float_divisor =
333 255.0f; // Converting a quantized uint8 probability to floating point
334 ESP_LOGD(TAG, "Detected '%s' with sliding average probability is %.2f and max probability is %.2f",
335 detection_event.wake_word->c_str(), (detection_event.average_probability / uint8_to_float_divisor),
336 (detection_event.max_probability / uint8_to_float_divisor));
337 this->wake_word_detected_trigger_.trigger(*detection_event.wake_word);
338 if (this->stop_after_detection_) {
339 this->stop();
340 }
341 }
342 }
343 break;
344 }
345 case State::STOPPING:
346 xEventGroupSetBits(this->event_group_, EventGroupBits::COMMAND_STOP);
347 break;
348 case State::STOPPED:
349 break;
350 }
351}
352
354 if (!this->is_ready()) {
355 ESP_LOGW(TAG, "Wake word detection can't start as the component hasn't been setup yet");
356 return;
357 }
358
359 if (this->is_failed()) {
360 ESP_LOGW(TAG, "Wake word component is marked as failed. Please check setup logs");
361 return;
362 }
363
364 if (this->is_running()) {
365 ESP_LOGW(TAG, "Wake word detection is already running");
366 return;
367 }
368
369 ESP_LOGD(TAG, "Starting wake word detection");
370
371 this->pending_start_ = true;
372 this->pending_stop_ = false;
373}
374
376 if (this->state_ == STOPPED)
377 return;
378
379 ESP_LOGD(TAG, "Stopping wake word detection");
380
381 this->pending_start_ = false;
382 this->pending_stop_ = true;
383}
384
386 if (this->state_ != state) {
387 ESP_LOGD(TAG, "State changed from %s to %s", LOG_STR_ARG(micro_wake_word_state_to_string(this->state_)),
388 LOG_STR_ARG(micro_wake_word_state_to_string(state)));
389 this->state_ = state;
390 }
391}
392
393bool MicroWakeWord::generate_features_(const int16_t *audio_buffer, size_t samples_available,
394 int8_t features_buffer[PREPROCESSOR_FEATURE_SIZE], size_t *processed_samples) {
395 *processed_samples = 0;
396 struct FrontendOutput frontend_output =
397 FrontendProcessSamples(&this->frontend_state_, audio_buffer, samples_available, processed_samples);
398
399 if (frontend_output.size == 0) {
400 return false;
401 }
402
403 for (size_t i = 0; i < frontend_output.size; ++i) {
404 // These scaling values are set to match the TFLite audio frontend int8 output.
405 // The feature pipeline outputs 16-bit signed integers in roughly a 0 to 670
406 // range. In training, these are then arbitrarily divided by 25.6 to get
407 // float values in the rough range of 0.0 to 26.0. This scaling is performed
408 // for historical reasons, to match up with the output of other feature
409 // generators.
410 // The process is then further complicated when we quantize the model. This
411 // means we have to scale the 0.0 to 26.0 real values to the -128 (INT8_MIN)
412 // to 127 (INT8_MAX) signed integer numbers.
413 // All this means that to get matching values from our integer feature
414 // output into the tensor input, we have to perform:
415 // input = (((feature / 25.6) / 26.0) * 256) - 128
416 // To simplify this and perform it in 32-bit integer math, we rearrange to:
417 // input = (feature * 256) / (25.6 * 26.0) - 128
418 constexpr int32_t value_scale = 256;
419 constexpr int32_t value_div = 666; // 666 = 25.6 * 26.0 after rounding
420 int32_t value = ((frontend_output.values[i] * value_scale) + (value_div / 2)) / value_div;
421
422 value += INT8_MIN; // Adds a -128; i.e., subtracts 128
423 features_buffer[i] = static_cast<int8_t>(clamp<int32_t>(value, INT8_MIN, INT8_MAX));
424 }
425
426 return true;
427}
428
430#ifdef USE_MICRO_WAKE_WORD_VAD
431 DetectionEvent vad_state = this->vad_model_->determine_detected();
432
433 this->vad_state_ = vad_state.detected; // atomic write, so thread safe
434#endif
435
436 for (auto &model : this->wake_word_models_) {
437 if (model->get_unprocessed_probability_status()) {
438 // Only detect wake words if there is a new probability since the last check
439 DetectionEvent wake_word_state = model->determine_detected();
440 if (wake_word_state.detected) {
441#ifdef USE_MICRO_WAKE_WORD_VAD
442 if (vad_state.detected) {
443#endif
444 xQueueSend(this->detection_queue_, &wake_word_state, portMAX_DELAY);
445
446 // Wake main loop immediately to process wake word detection
448
449 model->reset_probabilities();
450#ifdef USE_MICRO_WAKE_WORD_VAD
451 } else {
452 wake_word_state.blocked_by_vad = true;
453 xQueueSend(this->detection_queue_, &wake_word_state, portMAX_DELAY);
454 }
455#endif
456 }
457 }
458 }
459}
460
462 for (auto &model : this->wake_word_models_) {
463 model->unload_model();
464 }
465#ifdef USE_MICRO_WAKE_WORD_VAD
466 this->vad_model_->unload_model();
467#endif
468}
469
470bool MicroWakeWord::update_model_probabilities_(const int8_t audio_features[PREPROCESSOR_FEATURE_SIZE]) {
471 bool success = true;
472
473 for (auto &model : this->wake_word_models_) {
474 // Perform inference
475 success = success & model->perform_streaming_inference(audio_features);
476 }
477#ifdef USE_MICRO_WAKE_WORD_VAD
478 success = success & this->vad_model_->perform_streaming_inference(audio_features);
479#endif
480
481 return success;
482}
483
484} // namespace esphome::micro_wake_word
485
486#endif // USE_ESP32
void wake_loop_threadsafe()
Wake the main event loop from another thread or callback.
void mark_failed()
Mark this component as failed.
void status_momentary_error(const char *name, uint32_t length=5000)
Set error status flag and automatically clear it after a timeout.
bool is_failed() const
Definition component.h:272
bool is_ready() const
bool create(TaskFunction_t fn, const char *name, uint32_t stack_size, void *param, UBaseType_t priority, bool use_psram)
Allocate stack and create task.
bool is_created() const
Check if the task has been created and not yet destroyed.
Definition static_task.h:18
TaskHandle_t get_handle() const
Get the FreeRTOS task handle.
Definition static_task.h:21
void deallocate()
Delete the task (if running) and free the stack buffer.
void trigger(const Ts &...x) ESPHOME_ALWAYS_INLINE
Inform the parent automation that the event has triggered.
Definition automation.h:461
size_t frames_to_bytes(uint32_t frames) const
Converts frames to bytes.
Definition audio.h:53
uint32_t get_sample_rate() const
Definition audio.h:30
static std::unique_ptr< RingBufferAudioSource > create(std::shared_ptr< ring_buffer::RingBuffer > ring_buffer, size_t max_fill_bytes, uint8_t alignment_bytes=1)
Creates a new ring-buffer-backed audio source after validating its parameters.
void resume_task_()
Resumes the inference task.
microphone::MicrophoneSource * microphone_source_
void process_probabilities_()
Processes any new probabilities for each model.
std::vector< WakeWordModel * > wake_word_models_
void suspend_task_()
Suspends the inference task.
Trigger< std::string > wake_word_detected_trigger_
void add_wake_word_model(WakeWordModel *model)
bool generate_features_(const int16_t *audio_buffer, size_t samples_available, int8_t features_buffer[PREPROCESSOR_FEATURE_SIZE], size_t *processed_samples)
Generates a spectrogram feature from an input buffer of audio samples.
bool update_model_probabilities_(const int8_t audio_features[PREPROCESSOR_FEATURE_SIZE])
Runs an inference with each model using the new spectrogram features.
std::unique_ptr< VADModel > vad_model_
std::weak_ptr< ring_buffer::RingBuffer > ring_buffer_
void add_vad_model(const uint8_t *model_start, uint8_t probability_cutoff, size_t sliding_window_size, size_t tensor_arena_size)
void unload_models_()
Deletes each model's TFLite interpreters and frees tensor arena memory.
std::vector< WakeWordModel * > get_wake_words()
void on_ota_global_state(ota::OTAState state, float progress, uint8_t error, ota::OTAComponent *comp) override
void add_data_callback(F &&data_callback)
audio::AudioStreamInfo get_audio_stream_info()
Gets the AudioStreamInfo of the data after processing.
void add_global_state_listener(OTAGlobalStateListener *listener)
static std::unique_ptr< RingBuffer > create(size_t len, MemoryPreference preference=MemoryPreference::EXTERNAL_FIRST)
bool state
Definition fan.h:2
__int64 ssize_t
Definition httplib.h:178
OTAGlobalCallback * get_global_ota_callback()
constexpr float AFTER_CONNECTION
For components that should be initialized after a data connection (API/MQTT) is connected.
Definition component.h:55
Application App
Global storage of Application pointer - only one Application can exist.
static void uint32_t