ESPHome 2026.1.0-dev
Loading...
Searching...
No Matches
micro_wake_word.cpp
Go to the documentation of this file.
1#include "micro_wake_word.h"
2
3#ifdef USE_ESP32
4
6#include "esphome/core/hal.h"
8#include "esphome/core/log.h"
9
11
12#ifdef USE_OTA
14#endif
15
16namespace esphome {
17namespace micro_wake_word {
18
19static const char *const TAG = "micro_wake_word";
20
21static const ssize_t DETECTION_QUEUE_LENGTH = 5;
22
23static const size_t DATA_TIMEOUT_MS = 50;
24
25static const uint32_t RING_BUFFER_DURATION_MS = 120;
26
27static const uint32_t INFERENCE_TASK_STACK_SIZE = 3072;
28static const UBaseType_t INFERENCE_TASK_PRIORITY = 3;
29
// Bit flags exchanged between the main loop and the inference task through the FreeRTOS event group.
enum EventGroupBits : uint32_t {
  COMMAND_STOP = (1 << 0),  // Signals the inference task should stop

  // Lifecycle notifications set by the inference task and consumed by the main loop
  TASK_STARTING = (1 << 3),
  TASK_RUNNING = (1 << 4),
  TASK_STOPPING = (1 << 5),
  TASK_STOPPED = (1 << 6),

  // Failures reported by the inference task
  ERROR_MEMORY = (1 << 9),
  ERROR_INFERENCE = (1 << 10),

  // Reported when incoming audio did not fit in the ring buffer and the buffer had to be reset.
  // NOTE(review): bit position follows the spacing of the other groups - confirm against upstream.
  WARNING_FULL_RING_BUFFER = (1 << 13),

  ERROR_BITS = ERROR_MEMORY | ERROR_INFERENCE,  // Any error condition
  ALL_BITS = 0xfffff,  // Mask over the lowest 20 bits; covers every flag above (an event group offers 24 bits)
};
46
48
49static const LogString *micro_wake_word_state_to_string(State state) {
50 switch (state) {
51 case State::STARTING:
52 return LOG_STR("STARTING");
54 return LOG_STR("DETECTING_WAKE_WORD");
55 case State::STOPPING:
56 return LOG_STR("STOPPING");
57 case State::STOPPED:
58 return LOG_STR("STOPPED");
59 default:
60 return LOG_STR("UNKNOWN");
61 }
62}
63
65 ESP_LOGCONFIG(TAG, "microWakeWord:");
66 ESP_LOGCONFIG(TAG, " models:");
67 for (auto &model : this->wake_word_models_) {
68 model->log_model_config();
69 }
70#ifdef USE_MICRO_WAKE_WORD_VAD
71 this->vad_model_->log_model_config();
72#endif
73}
74
76 this->frontend_config_.window.size_ms = FEATURE_DURATION_MS;
77 this->frontend_config_.window.step_size_ms = this->features_step_size_;
78 this->frontend_config_.filterbank.num_channels = PREPROCESSOR_FEATURE_SIZE;
79 this->frontend_config_.filterbank.lower_band_limit = FILTERBANK_LOWER_BAND_LIMIT;
80 this->frontend_config_.filterbank.upper_band_limit = FILTERBANK_UPPER_BAND_LIMIT;
81 this->frontend_config_.noise_reduction.smoothing_bits = NOISE_REDUCTION_SMOOTHING_BITS;
82 this->frontend_config_.noise_reduction.even_smoothing = NOISE_REDUCTION_EVEN_SMOOTHING;
83 this->frontend_config_.noise_reduction.odd_smoothing = NOISE_REDUCTION_ODD_SMOOTHING;
84 this->frontend_config_.noise_reduction.min_signal_remaining = NOISE_REDUCTION_MIN_SIGNAL_REMAINING;
85 this->frontend_config_.pcan_gain_control.enable_pcan = PCAN_GAIN_CONTROL_ENABLE_PCAN;
86 this->frontend_config_.pcan_gain_control.strength = PCAN_GAIN_CONTROL_STRENGTH;
87 this->frontend_config_.pcan_gain_control.offset = PCAN_GAIN_CONTROL_OFFSET;
88 this->frontend_config_.pcan_gain_control.gain_bits = PCAN_GAIN_CONTROL_GAIN_BITS;
89 this->frontend_config_.log_scale.enable_log = LOG_SCALE_ENABLE_LOG;
90 this->frontend_config_.log_scale.scale_shift = LOG_SCALE_SCALE_SHIFT;
91
92 this->event_group_ = xEventGroupCreate();
93 if (this->event_group_ == nullptr) {
94 ESP_LOGE(TAG, "Failed to create event group");
95 this->mark_failed();
96 return;
97 }
98
99 this->detection_queue_ = xQueueCreate(DETECTION_QUEUE_LENGTH, sizeof(DetectionEvent));
100 if (this->detection_queue_ == nullptr) {
101 ESP_LOGE(TAG, "Failed to create detection event queue");
102 this->mark_failed();
103 return;
104 }
105
106 this->microphone_source_->add_data_callback([this](const std::vector<uint8_t> &data) {
107 if (this->state_ == State::STOPPED) {
108 return;
109 }
110 std::shared_ptr<RingBuffer> temp_ring_buffer = this->ring_buffer_.lock();
111 if (this->ring_buffer_.use_count() > 1) {
112 size_t bytes_free = temp_ring_buffer->free();
113
114 if (bytes_free < data.size()) {
116 temp_ring_buffer->reset();
117 }
118 temp_ring_buffer->write((void *) data.data(), data.size());
119 }
120 });
121
122#ifdef USE_OTA_STATE_LISTENER
124#endif
125}
126
127#ifdef USE_OTA_STATE_LISTENER
128void MicroWakeWord::on_ota_global_state(ota::OTAState state, float progress, uint8_t error, ota::OTAComponent *comp) {
129 if (state == ota::OTA_STARTED) {
130 this->suspend_task_();
131 } else if (state == ota::OTA_ERROR) {
132 this->resume_task_();
133 }
134}
135#endif
136
138 MicroWakeWord *this_mww = (MicroWakeWord *) params;
139
140 xEventGroupSetBits(this_mww->event_group_, EventGroupBits::TASK_STARTING);
141
142 { // Ensures any C++ objects fall out of scope to deallocate before deleting the task
143
144 const size_t new_bytes_to_process =
146 std::unique_ptr<audio::AudioSourceTransferBuffer> audio_buffer;
147 int8_t features_buffer[PREPROCESSOR_FEATURE_SIZE];
148
149 if (!(xEventGroupGetBits(this_mww->event_group_) & ERROR_BITS)) {
150 // Allocate audio transfer buffer
151 audio_buffer = audio::AudioSourceTransferBuffer::create(new_bytes_to_process);
152
153 if (audio_buffer == nullptr) {
154 xEventGroupSetBits(this_mww->event_group_, EventGroupBits::ERROR_MEMORY);
155 }
156 }
157
158 if (!(xEventGroupGetBits(this_mww->event_group_) & ERROR_BITS)) {
159 // Allocate ring buffer
160 std::shared_ptr<RingBuffer> temp_ring_buffer = RingBuffer::create(
161 this_mww->microphone_source_->get_audio_stream_info().ms_to_bytes(RING_BUFFER_DURATION_MS));
162 if (temp_ring_buffer.use_count() == 0) {
163 xEventGroupSetBits(this_mww->event_group_, EventGroupBits::ERROR_MEMORY);
164 }
165 audio_buffer->set_source(temp_ring_buffer);
166 this_mww->ring_buffer_ = temp_ring_buffer;
167 }
168
169 if (!(xEventGroupGetBits(this_mww->event_group_) & ERROR_BITS)) {
170 this_mww->microphone_source_->start();
171 xEventGroupSetBits(this_mww->event_group_, EventGroupBits::TASK_RUNNING);
172
173 while (!(xEventGroupGetBits(this_mww->event_group_) & COMMAND_STOP)) {
174 audio_buffer->transfer_data_from_source(pdMS_TO_TICKS(DATA_TIMEOUT_MS));
175
176 if (audio_buffer->available() < new_bytes_to_process) {
177 // Insufficient data to generate new spectrogram features, read more next iteration
178 continue;
179 }
180
181 // Generate new spectrogram features
182 uint32_t processed_samples = this_mww->generate_features_(
183 (int16_t *) audio_buffer->get_buffer_start(), audio_buffer->available() / sizeof(int16_t), features_buffer);
184 audio_buffer->decrease_buffer_length(processed_samples * sizeof(int16_t));
185
186 // Run inference using the new spectorgram features
187 if (!this_mww->update_model_probabilities_(features_buffer)) {
188 xEventGroupSetBits(this_mww->event_group_, EventGroupBits::ERROR_INFERENCE);
189 break;
190 }
191
192 // Process each model's probabilities and possibly send a Detection Event to the queue
193 this_mww->process_probabilities_();
194 }
195 }
196 }
197
198 xEventGroupSetBits(this_mww->event_group_, EventGroupBits::TASK_STOPPING);
199
200 this_mww->unload_models_();
201 this_mww->microphone_source_->stop();
202 FrontendFreeStateContents(&this_mww->frontend_state_);
203
204 xEventGroupSetBits(this_mww->event_group_, EventGroupBits::TASK_STOPPED);
205 while (true) {
206 // Continuously delay until the main loop deletes the task
207 delay(10);
208 }
209}
210
211std::vector<WakeWordModel *> MicroWakeWord::get_wake_words() {
212 std::vector<WakeWordModel *> external_wake_word_models;
213 for (auto *model : this->wake_word_models_) {
214 if (!model->get_internal_only()) {
215 external_wake_word_models.push_back(model);
216 }
217 }
218 return external_wake_word_models;
219}
220
222
223#ifdef USE_MICRO_WAKE_WORD_VAD
224void MicroWakeWord::add_vad_model(const uint8_t *model_start, uint8_t probability_cutoff, size_t sliding_window_size,
225 size_t tensor_arena_size) {
226 this->vad_model_ = make_unique<VADModel>(model_start, probability_cutoff, sliding_window_size, tensor_arena_size);
227}
228#endif
229
231 if (this->inference_task_handle_ != nullptr) {
232 vTaskSuspend(this->inference_task_handle_);
233 }
234}
235
237 if (this->inference_task_handle_ != nullptr) {
238 vTaskResume(this->inference_task_handle_);
239 }
240}
241
243 uint32_t event_group_bits = xEventGroupGetBits(this->event_group_);
244
245 if (event_group_bits & EventGroupBits::ERROR_MEMORY) {
246 xEventGroupClearBits(this->event_group_, EventGroupBits::ERROR_MEMORY);
247 ESP_LOGE(TAG, "Encountered an error allocating buffers");
248 }
249
250 if (event_group_bits & EventGroupBits::ERROR_INFERENCE) {
251 xEventGroupClearBits(this->event_group_, EventGroupBits::ERROR_INFERENCE);
252 ESP_LOGE(TAG, "Encountered an error while performing an inference");
253 }
254
255 if (event_group_bits & EventGroupBits::WARNING_FULL_RING_BUFFER) {
256 xEventGroupClearBits(this->event_group_, EventGroupBits::WARNING_FULL_RING_BUFFER);
257 ESP_LOGW(TAG, "Not enough free bytes in ring buffer to store incoming audio data. Resetting the ring buffer. Wake "
258 "word detection accuracy will temporarily be reduced.");
259 }
260
261 if (event_group_bits & EventGroupBits::TASK_STARTING) {
262 ESP_LOGD(TAG, "Inference task has started, attempting to allocate memory for buffers");
263 xEventGroupClearBits(this->event_group_, EventGroupBits::TASK_STARTING);
264 }
265
266 if (event_group_bits & EventGroupBits::TASK_RUNNING) {
267 ESP_LOGD(TAG, "Inference task is running");
268
269 xEventGroupClearBits(this->event_group_, EventGroupBits::TASK_RUNNING);
271 }
272
273 if (event_group_bits & EventGroupBits::TASK_STOPPING) {
274 ESP_LOGD(TAG, "Inference task is stopping, deallocating buffers");
275 xEventGroupClearBits(this->event_group_, EventGroupBits::TASK_STOPPING);
276 }
277
278 if ((event_group_bits & EventGroupBits::TASK_STOPPED)) {
279 ESP_LOGD(TAG, "Inference task is finished, freeing task resources");
280 vTaskDelete(this->inference_task_handle_);
281 this->inference_task_handle_ = nullptr;
282 xEventGroupClearBits(this->event_group_, ALL_BITS);
283 xQueueReset(this->detection_queue_);
285 }
286
287 if ((this->pending_start_) && (this->state_ == State::STOPPED)) {
289 this->pending_start_ = false;
290 }
291
292 if ((this->pending_stop_) && (this->state_ == State::DETECTING_WAKE_WORD)) {
294 this->pending_stop_ = false;
295 }
296
297 switch (this->state_) {
298 case State::STARTING:
299 if ((this->inference_task_handle_ == nullptr) && !this->status_has_error()) {
300 // Setup preprocesor feature generator. If done in the task, it would lock the task to its initial core, as it
301 // uses floating point operations.
302 if (!FrontendPopulateState(&this->frontend_config_, &this->frontend_state_,
304 this->status_momentary_error("frontend_alloc", 1000);
305 return;
306 }
307
308 xTaskCreate(MicroWakeWord::inference_task, "mww", INFERENCE_TASK_STACK_SIZE, (void *) this,
309 INFERENCE_TASK_PRIORITY, &this->inference_task_handle_);
310
311 if (this->inference_task_handle_ == nullptr) {
312 FrontendFreeStateContents(&this->frontend_state_); // Deallocate frontend state
313 this->status_momentary_error("task_start", 1000);
314 }
315 }
316 break;
318 DetectionEvent detection_event;
319 while (xQueueReceive(this->detection_queue_, &detection_event, 0)) {
320 if (detection_event.blocked_by_vad) {
321 ESP_LOGD(TAG, "Wake word model predicts '%s', but VAD model doesn't.", detection_event.wake_word->c_str());
322 } else {
323 constexpr float uint8_to_float_divisor =
324 255.0f; // Converting a quantized uint8 probability to floating point
325 ESP_LOGD(TAG, "Detected '%s' with sliding average probability is %.2f and max probability is %.2f",
326 detection_event.wake_word->c_str(), (detection_event.average_probability / uint8_to_float_divisor),
327 (detection_event.max_probability / uint8_to_float_divisor));
328 this->wake_word_detected_trigger_->trigger(*detection_event.wake_word);
329 if (this->stop_after_detection_) {
330 this->stop();
331 }
332 }
333 }
334 break;
335 }
336 case State::STOPPING:
337 xEventGroupSetBits(this->event_group_, EventGroupBits::COMMAND_STOP);
338 break;
339 case State::STOPPED:
340 break;
341 }
342}
343
345 if (!this->is_ready()) {
346 ESP_LOGW(TAG, "Wake word detection can't start as the component hasn't been setup yet");
347 return;
348 }
349
350 if (this->is_failed()) {
351 ESP_LOGW(TAG, "Wake word component is marked as failed. Please check setup logs");
352 return;
353 }
354
355 if (this->is_running()) {
356 ESP_LOGW(TAG, "Wake word detection is already running");
357 return;
358 }
359
360 ESP_LOGD(TAG, "Starting wake word detection");
361
362 this->pending_start_ = true;
363 this->pending_stop_ = false;
364}
365
367 if (this->state_ == STOPPED)
368 return;
369
370 ESP_LOGD(TAG, "Stopping wake word detection");
371
372 this->pending_start_ = false;
373 this->pending_stop_ = true;
374}
375
377 if (this->state_ != state) {
378 ESP_LOGD(TAG, "State changed from %s to %s", LOG_STR_ARG(micro_wake_word_state_to_string(this->state_)),
379 LOG_STR_ARG(micro_wake_word_state_to_string(state)));
380 this->state_ = state;
381 }
382}
383
384size_t MicroWakeWord::generate_features_(int16_t *audio_buffer, size_t samples_available,
385 int8_t features_buffer[PREPROCESSOR_FEATURE_SIZE]) {
386 size_t processed_samples = 0;
387 struct FrontendOutput frontend_output =
388 FrontendProcessSamples(&this->frontend_state_, audio_buffer, samples_available, &processed_samples);
389
390 for (size_t i = 0; i < frontend_output.size; ++i) {
391 // These scaling values are set to match the TFLite audio frontend int8 output.
392 // The feature pipeline outputs 16-bit signed integers in roughly a 0 to 670
393 // range. In training, these are then arbitrarily divided by 25.6 to get
394 // float values in the rough range of 0.0 to 26.0. This scaling is performed
395 // for historical reasons, to match up with the output of other feature
396 // generators.
397 // The process is then further complicated when we quantize the model. This
398 // means we have to scale the 0.0 to 26.0 real values to the -128 (INT8_MIN)
399 // to 127 (INT8_MAX) signed integer numbers.
400 // All this means that to get matching values from our integer feature
401 // output into the tensor input, we have to perform:
402 // input = (((feature / 25.6) / 26.0) * 256) - 128
403 // To simplify this and perform it in 32-bit integer math, we rearrange to:
404 // input = (feature * 256) / (25.6 * 26.0) - 128
405 constexpr int32_t value_scale = 256;
406 constexpr int32_t value_div = 666; // 666 = 25.6 * 26.0 after rounding
407 int32_t value = ((frontend_output.values[i] * value_scale) + (value_div / 2)) / value_div;
408
409 value += INT8_MIN; // Adds a -128; i.e., subtracts 128
410 features_buffer[i] = static_cast<int8_t>(clamp<int32_t>(value, INT8_MIN, INT8_MAX));
411 }
412
413 return processed_samples;
414}
415
417#ifdef USE_MICRO_WAKE_WORD_VAD
418 DetectionEvent vad_state = this->vad_model_->determine_detected();
419
420 this->vad_state_ = vad_state.detected; // atomic write, so thread safe
421#endif
422
423 for (auto &model : this->wake_word_models_) {
424 if (model->get_unprocessed_probability_status()) {
425 // Only detect wake words if there is a new probability since the last check
426 DetectionEvent wake_word_state = model->determine_detected();
427 if (wake_word_state.detected) {
428#ifdef USE_MICRO_WAKE_WORD_VAD
429 if (vad_state.detected) {
430#endif
431 xQueueSend(this->detection_queue_, &wake_word_state, portMAX_DELAY);
432
433 // Wake main loop immediately to process wake word detection
434#if defined(USE_SOCKET_SELECT_SUPPORT) && defined(USE_WAKE_LOOP_THREADSAFE)
436#endif
437
438 model->reset_probabilities();
439#ifdef USE_MICRO_WAKE_WORD_VAD
440 } else {
441 wake_word_state.blocked_by_vad = true;
442 xQueueSend(this->detection_queue_, &wake_word_state, portMAX_DELAY);
443 }
444#endif
445 }
446 }
447 }
448}
449
451 for (auto &model : this->wake_word_models_) {
452 model->unload_model();
453 }
454#ifdef USE_MICRO_WAKE_WORD_VAD
455 this->vad_model_->unload_model();
456#endif
457}
458
459bool MicroWakeWord::update_model_probabilities_(const int8_t audio_features[PREPROCESSOR_FEATURE_SIZE]) {
460 bool success = true;
461
462 for (auto &model : this->wake_word_models_) {
463 // Perform inference
464 success = success & model->perform_streaming_inference(audio_features);
465 }
466#ifdef USE_MICRO_WAKE_WORD_VAD
467 success = success & this->vad_model_->perform_streaming_inference(audio_features);
468#endif
469
470 return success;
471}
472
473} // namespace micro_wake_word
474} // namespace esphome
475
476#endif // USE_ESP32
void wake_loop_threadsafe()
Wake the main event loop from a FreeRTOS task Thread-safe, can be called from task context to immedia...
virtual void mark_failed()
Mark this component as failed.
void status_momentary_error(const char *name, uint32_t length=5000)
Set error status flag and automatically clear it after a timeout.
bool is_failed() const
bool is_ready() const
bool status_has_error() const
static std::unique_ptr< RingBuffer > create(size_t len)
void trigger(const Ts &...x)
Inform the parent automation that the event has triggered.
Definition automation.h:204
static std::unique_ptr< AudioSourceTransferBuffer > create(size_t buffer_size)
Creates a new source transfer buffer.
size_t ms_to_bytes(uint32_t ms) const
Converts duration to bytes.
Definition audio.h:73
uint32_t get_sample_rate() const
Definition audio.h:30
void resume_task_()
Resumes the inference task.
microphone::MicrophoneSource * microphone_source_
void process_probabilities_()
Processes any new probabilities for each model.
std::weak_ptr< RingBuffer > ring_buffer_
Trigger< std::string > * wake_word_detected_trigger_
std::vector< WakeWordModel * > wake_word_models_
void suspend_task_()
Suspends the inference task.
void add_wake_word_model(WakeWordModel *model)
bool update_model_probabilities_(const int8_t audio_features[PREPROCESSOR_FEATURE_SIZE])
Runs an inference with each model using the new spectrogram features.
size_t generate_features_(int16_t *audio_buffer, size_t samples_available, int8_t features_buffer[PREPROCESSOR_FEATURE_SIZE])
Generates spectrogram features from an input buffer of audio samples.
std::unique_ptr< VADModel > vad_model_
void add_vad_model(const uint8_t *model_start, uint8_t probability_cutoff, size_t sliding_window_size, size_t tensor_arena_size)
void unload_models_()
Deletes each model's TFLite interpreters and frees tensor arena memory.
std::vector< WakeWordModel * > get_wake_words()
void on_ota_global_state(ota::OTAState state, float progress, uint8_t error, ota::OTAComponent *comp) override
void add_data_callback(std::function< void(const std::vector< uint8_t > &)> &&data_callback)
audio::AudioStreamInfo get_audio_stream_info()
Gets the AudioStreamInfo of the data after processing.
void add_global_state_listener(OTAGlobalStateListener *listener)
bool state
Definition fan.h:0
__int64 ssize_t
Definition httplib.h:178
OTAGlobalCallback * get_global_ota_callback()
const float AFTER_CONNECTION
For components that should be initialized after a data connection (API/MQTT) is connected.
Definition component.cpp:89
Providing packet encoding functions for exchanging data with a remote host.
Definition a01nyub.cpp:7
void IRAM_ATTR HOT delay(uint32_t ms)
Definition core.cpp:26
Application App
Global storage of Application pointer - only one Application can exist.