ESPHome 2025.12.0-dev
Loading...
Searching...
No Matches
micro_wake_word.cpp
Go to the documentation of this file.
1#include "micro_wake_word.h"
2
3#ifdef USE_ESP_IDF
4
6#include "esphome/core/hal.h"
8#include "esphome/core/log.h"
9
11
12#ifdef USE_OTA
14#endif
15
16namespace esphome {
17namespace micro_wake_word {
18
19static const char *const TAG = "micro_wake_word";
20
21static const ssize_t DETECTION_QUEUE_LENGTH = 5;
22
23static const size_t DATA_TIMEOUT_MS = 50;
24
25static const uint32_t RING_BUFFER_DURATION_MS = 120;
26
27static const uint32_t INFERENCE_TASK_STACK_SIZE = 3072;
28static const UBaseType_t INFERENCE_TASK_PRIORITY = 3;
29
/// Bit flags stored in the component's FreeRTOS event group.
enum EventGroupBits : uint32_t {
  COMMAND_STOP = (1 << 0),  // Signals the inference task should stop

  // Inference task lifecycle notifications
  TASK_STARTING = (1 << 3),
  TASK_RUNNING = (1 << 4),
  TASK_STOPPING = (1 << 5),
  TASK_STOPPED = (1 << 6),

  // Failures raised by the inference task
  ERROR_MEMORY = (1 << 9),
  ERROR_INFERENCE = (1 << 10),

  // Reported when incoming audio did not fit and the ring buffer had to be reset.
  // NOTE(review): bit position restored from the spacing pattern of the other flags - confirm against the original.
  WARNING_FULL_RING_BUFFER = (1 << 13),

  ERROR_BITS = ERROR_MEMORY | ERROR_INFERENCE,  // Mask that matches any error flag
  ALL_BITS = 0xfffff,  // Covers every flag above (low 20 bits); an event group offers 24 usable bits in total
};
46
48
49static const LogString *micro_wake_word_state_to_string(State state) {
50 switch (state) {
51 case State::STARTING:
52 return LOG_STR("STARTING");
54 return LOG_STR("DETECTING_WAKE_WORD");
55 case State::STOPPING:
56 return LOG_STR("STOPPING");
57 case State::STOPPED:
58 return LOG_STR("STOPPED");
59 default:
60 return LOG_STR("UNKNOWN");
61 }
62}
63
65 ESP_LOGCONFIG(TAG, "microWakeWord:");
66 ESP_LOGCONFIG(TAG, " models:");
67 for (auto &model : this->wake_word_models_) {
68 model->log_model_config();
69 }
70#ifdef USE_MICRO_WAKE_WORD_VAD
71 this->vad_model_->log_model_config();
72#endif
73}
74
76 this->frontend_config_.window.size_ms = FEATURE_DURATION_MS;
77 this->frontend_config_.window.step_size_ms = this->features_step_size_;
78 this->frontend_config_.filterbank.num_channels = PREPROCESSOR_FEATURE_SIZE;
79 this->frontend_config_.filterbank.lower_band_limit = FILTERBANK_LOWER_BAND_LIMIT;
80 this->frontend_config_.filterbank.upper_band_limit = FILTERBANK_UPPER_BAND_LIMIT;
81 this->frontend_config_.noise_reduction.smoothing_bits = NOISE_REDUCTION_SMOOTHING_BITS;
82 this->frontend_config_.noise_reduction.even_smoothing = NOISE_REDUCTION_EVEN_SMOOTHING;
83 this->frontend_config_.noise_reduction.odd_smoothing = NOISE_REDUCTION_ODD_SMOOTHING;
84 this->frontend_config_.noise_reduction.min_signal_remaining = NOISE_REDUCTION_MIN_SIGNAL_REMAINING;
85 this->frontend_config_.pcan_gain_control.enable_pcan = PCAN_GAIN_CONTROL_ENABLE_PCAN;
86 this->frontend_config_.pcan_gain_control.strength = PCAN_GAIN_CONTROL_STRENGTH;
87 this->frontend_config_.pcan_gain_control.offset = PCAN_GAIN_CONTROL_OFFSET;
88 this->frontend_config_.pcan_gain_control.gain_bits = PCAN_GAIN_CONTROL_GAIN_BITS;
89 this->frontend_config_.log_scale.enable_log = LOG_SCALE_ENABLE_LOG;
90 this->frontend_config_.log_scale.scale_shift = LOG_SCALE_SCALE_SHIFT;
91
92 this->event_group_ = xEventGroupCreate();
93 if (this->event_group_ == nullptr) {
94 ESP_LOGE(TAG, "Failed to create event group");
95 this->mark_failed();
96 return;
97 }
98
99 this->detection_queue_ = xQueueCreate(DETECTION_QUEUE_LENGTH, sizeof(DetectionEvent));
100 if (this->detection_queue_ == nullptr) {
101 ESP_LOGE(TAG, "Failed to create detection event queue");
102 this->mark_failed();
103 return;
104 }
105
106 this->microphone_source_->add_data_callback([this](const std::vector<uint8_t> &data) {
107 if (this->state_ == State::STOPPED) {
108 return;
109 }
110 std::shared_ptr<RingBuffer> temp_ring_buffer = this->ring_buffer_.lock();
111 if (this->ring_buffer_.use_count() > 1) {
112 size_t bytes_free = temp_ring_buffer->free();
113
114 if (bytes_free < data.size()) {
116 temp_ring_buffer->reset();
117 }
118 temp_ring_buffer->write((void *) data.data(), data.size());
119 }
120 });
121
122#ifdef USE_OTA
124 [this](ota::OTAState state, float progress, uint8_t error, ota::OTAComponent *comp) {
125 if (state == ota::OTA_STARTED) {
126 this->suspend_task_();
127 } else if (state == ota::OTA_ERROR) {
128 this->resume_task_();
129 }
130 });
131#endif
132}
133
135 MicroWakeWord *this_mww = (MicroWakeWord *) params;
136
137 xEventGroupSetBits(this_mww->event_group_, EventGroupBits::TASK_STARTING);
138
139 { // Ensures any C++ objects fall out of scope to deallocate before deleting the task
140
141 const size_t new_bytes_to_process =
143 std::unique_ptr<audio::AudioSourceTransferBuffer> audio_buffer;
144 int8_t features_buffer[PREPROCESSOR_FEATURE_SIZE];
145
146 if (!(xEventGroupGetBits(this_mww->event_group_) & ERROR_BITS)) {
147 // Allocate audio transfer buffer
148 audio_buffer = audio::AudioSourceTransferBuffer::create(new_bytes_to_process);
149
150 if (audio_buffer == nullptr) {
151 xEventGroupSetBits(this_mww->event_group_, EventGroupBits::ERROR_MEMORY);
152 }
153 }
154
155 if (!(xEventGroupGetBits(this_mww->event_group_) & ERROR_BITS)) {
156 // Allocate ring buffer
157 std::shared_ptr<RingBuffer> temp_ring_buffer = RingBuffer::create(
158 this_mww->microphone_source_->get_audio_stream_info().ms_to_bytes(RING_BUFFER_DURATION_MS));
159 if (temp_ring_buffer.use_count() == 0) {
160 xEventGroupSetBits(this_mww->event_group_, EventGroupBits::ERROR_MEMORY);
161 }
162 audio_buffer->set_source(temp_ring_buffer);
163 this_mww->ring_buffer_ = temp_ring_buffer;
164 }
165
166 if (!(xEventGroupGetBits(this_mww->event_group_) & ERROR_BITS)) {
167 this_mww->microphone_source_->start();
168 xEventGroupSetBits(this_mww->event_group_, EventGroupBits::TASK_RUNNING);
169
170 while (!(xEventGroupGetBits(this_mww->event_group_) & COMMAND_STOP)) {
171 audio_buffer->transfer_data_from_source(pdMS_TO_TICKS(DATA_TIMEOUT_MS));
172
173 if (audio_buffer->available() < new_bytes_to_process) {
174 // Insufficient data to generate new spectrogram features, read more next iteration
175 continue;
176 }
177
178 // Generate new spectrogram features
179 uint32_t processed_samples = this_mww->generate_features_(
180 (int16_t *) audio_buffer->get_buffer_start(), audio_buffer->available() / sizeof(int16_t), features_buffer);
181 audio_buffer->decrease_buffer_length(processed_samples * sizeof(int16_t));
182
183 // Run inference using the new spectrogram features
184 if (!this_mww->update_model_probabilities_(features_buffer)) {
185 xEventGroupSetBits(this_mww->event_group_, EventGroupBits::ERROR_INFERENCE);
186 break;
187 }
188
189 // Process each model's probabilities and possibly send a Detection Event to the queue
190 this_mww->process_probabilities_();
191 }
192 }
193 }
194
195 xEventGroupSetBits(this_mww->event_group_, EventGroupBits::TASK_STOPPING);
196
197 this_mww->unload_models_();
198 this_mww->microphone_source_->stop();
199 FrontendFreeStateContents(&this_mww->frontend_state_);
200
201 xEventGroupSetBits(this_mww->event_group_, EventGroupBits::TASK_STOPPED);
202 while (true) {
203 // Continuously delay until the main loop deletes the task
204 delay(10);
205 }
206}
207
208std::vector<WakeWordModel *> MicroWakeWord::get_wake_words() {
209 std::vector<WakeWordModel *> external_wake_word_models;
210 for (auto *model : this->wake_word_models_) {
211 if (!model->get_internal_only()) {
212 external_wake_word_models.push_back(model);
213 }
214 }
215 return external_wake_word_models;
216}
217
219
220#ifdef USE_MICRO_WAKE_WORD_VAD
221void MicroWakeWord::add_vad_model(const uint8_t *model_start, uint8_t probability_cutoff, size_t sliding_window_size,
222 size_t tensor_arena_size) {
223 this->vad_model_ = make_unique<VADModel>(model_start, probability_cutoff, sliding_window_size, tensor_arena_size);
224}
225#endif
226
228 if (this->inference_task_handle_ != nullptr) {
229 vTaskSuspend(this->inference_task_handle_);
230 }
231}
232
234 if (this->inference_task_handle_ != nullptr) {
235 vTaskResume(this->inference_task_handle_);
236 }
237}
238
240 uint32_t event_group_bits = xEventGroupGetBits(this->event_group_);
241
242 if (event_group_bits & EventGroupBits::ERROR_MEMORY) {
243 xEventGroupClearBits(this->event_group_, EventGroupBits::ERROR_MEMORY);
244 ESP_LOGE(TAG, "Encountered an error allocating buffers");
245 }
246
247 if (event_group_bits & EventGroupBits::ERROR_INFERENCE) {
248 xEventGroupClearBits(this->event_group_, EventGroupBits::ERROR_INFERENCE);
249 ESP_LOGE(TAG, "Encountered an error while performing an inference");
250 }
251
252 if (event_group_bits & EventGroupBits::WARNING_FULL_RING_BUFFER) {
253 xEventGroupClearBits(this->event_group_, EventGroupBits::WARNING_FULL_RING_BUFFER);
254 ESP_LOGW(TAG, "Not enough free bytes in ring buffer to store incoming audio data. Resetting the ring buffer. Wake "
255 "word detection accuracy will temporarily be reduced.");
256 }
257
258 if (event_group_bits & EventGroupBits::TASK_STARTING) {
259 ESP_LOGD(TAG, "Inference task has started, attempting to allocate memory for buffers");
260 xEventGroupClearBits(this->event_group_, EventGroupBits::TASK_STARTING);
261 }
262
263 if (event_group_bits & EventGroupBits::TASK_RUNNING) {
264 ESP_LOGD(TAG, "Inference task is running");
265
266 xEventGroupClearBits(this->event_group_, EventGroupBits::TASK_RUNNING);
268 }
269
270 if (event_group_bits & EventGroupBits::TASK_STOPPING) {
271 ESP_LOGD(TAG, "Inference task is stopping, deallocating buffers");
272 xEventGroupClearBits(this->event_group_, EventGroupBits::TASK_STOPPING);
273 }
274
275 if ((event_group_bits & EventGroupBits::TASK_STOPPED)) {
276 ESP_LOGD(TAG, "Inference task is finished, freeing task resources");
277 vTaskDelete(this->inference_task_handle_);
278 this->inference_task_handle_ = nullptr;
279 xEventGroupClearBits(this->event_group_, ALL_BITS);
280 xQueueReset(this->detection_queue_);
282 }
283
284 if ((this->pending_start_) && (this->state_ == State::STOPPED)) {
286 this->pending_start_ = false;
287 }
288
289 if ((this->pending_stop_) && (this->state_ == State::DETECTING_WAKE_WORD)) {
291 this->pending_stop_ = false;
292 }
293
294 switch (this->state_) {
295 case State::STARTING:
296 if ((this->inference_task_handle_ == nullptr) && !this->status_has_error()) {
297 // Setup preprocessor feature generator. If done in the task, it would lock the task to its initial core, as it
298 // uses floating point operations.
299 if (!FrontendPopulateState(&this->frontend_config_, &this->frontend_state_,
302 "Failed to allocate buffers for spectrogram feature processor, attempting again in 1 second", 1000);
303 return;
304 }
305
306 xTaskCreate(MicroWakeWord::inference_task, "mww", INFERENCE_TASK_STACK_SIZE, (void *) this,
307 INFERENCE_TASK_PRIORITY, &this->inference_task_handle_);
308
309 if (this->inference_task_handle_ == nullptr) {
310 FrontendFreeStateContents(&this->frontend_state_); // Deallocate frontend state
311 this->status_momentary_error("Task failed to start, attempting again in 1 second", 1000);
312 }
313 }
314 break;
316 DetectionEvent detection_event;
317 while (xQueueReceive(this->detection_queue_, &detection_event, 0)) {
318 if (detection_event.blocked_by_vad) {
319 ESP_LOGD(TAG, "Wake word model predicts '%s', but VAD model doesn't.", detection_event.wake_word->c_str());
320 } else {
321 constexpr float uint8_to_float_divisor =
322 255.0f; // Converting a quantized uint8 probability to floating point
323 ESP_LOGD(TAG, "Detected '%s' with sliding average probability is %.2f and max probability is %.2f",
324 detection_event.wake_word->c_str(), (detection_event.average_probability / uint8_to_float_divisor),
325 (detection_event.max_probability / uint8_to_float_divisor));
326 this->wake_word_detected_trigger_->trigger(*detection_event.wake_word);
327 if (this->stop_after_detection_) {
328 this->stop();
329 }
330 }
331 }
332 break;
333 }
334 case State::STOPPING:
335 xEventGroupSetBits(this->event_group_, EventGroupBits::COMMAND_STOP);
336 break;
337 case State::STOPPED:
338 break;
339 }
340}
341
343 if (!this->is_ready()) {
344 ESP_LOGW(TAG, "Wake word detection can't start as the component hasn't been setup yet");
345 return;
346 }
347
348 if (this->is_failed()) {
349 ESP_LOGW(TAG, "Wake word component is marked as failed. Please check setup logs");
350 return;
351 }
352
353 if (this->is_running()) {
354 ESP_LOGW(TAG, "Wake word detection is already running");
355 return;
356 }
357
358 ESP_LOGD(TAG, "Starting wake word detection");
359
360 this->pending_start_ = true;
361 this->pending_stop_ = false;
362}
363
365 if (this->state_ == STOPPED)
366 return;
367
368 ESP_LOGD(TAG, "Stopping wake word detection");
369
370 this->pending_start_ = false;
371 this->pending_stop_ = true;
372}
373
375 if (this->state_ != state) {
376 ESP_LOGD(TAG, "State changed from %s to %s", LOG_STR_ARG(micro_wake_word_state_to_string(this->state_)),
377 LOG_STR_ARG(micro_wake_word_state_to_string(state)));
378 this->state_ = state;
379 }
380}
381
382size_t MicroWakeWord::generate_features_(int16_t *audio_buffer, size_t samples_available,
383 int8_t features_buffer[PREPROCESSOR_FEATURE_SIZE]) {
384 size_t processed_samples = 0;
385 struct FrontendOutput frontend_output =
386 FrontendProcessSamples(&this->frontend_state_, audio_buffer, samples_available, &processed_samples);
387
388 for (size_t i = 0; i < frontend_output.size; ++i) {
389 // These scaling values are set to match the TFLite audio frontend int8 output.
390 // The feature pipeline outputs 16-bit signed integers in roughly a 0 to 670
391 // range. In training, these are then arbitrarily divided by 25.6 to get
392 // float values in the rough range of 0.0 to 26.0. This scaling is performed
393 // for historical reasons, to match up with the output of other feature
394 // generators.
395 // The process is then further complicated when we quantize the model. This
396 // means we have to scale the 0.0 to 26.0 real values to the -128 (INT8_MIN)
397 // to 127 (INT8_MAX) signed integer numbers.
398 // All this means that to get matching values from our integer feature
399 // output into the tensor input, we have to perform:
400 // input = (((feature / 25.6) / 26.0) * 256) - 128
401 // To simplify this and perform it in 32-bit integer math, we rearrange to:
402 // input = (feature * 256) / (25.6 * 26.0) - 128
403 constexpr int32_t value_scale = 256;
404 constexpr int32_t value_div = 666; // 666 = 25.6 * 26.0 after rounding
405 int32_t value = ((frontend_output.values[i] * value_scale) + (value_div / 2)) / value_div;
406
407 value += INT8_MIN; // Adds a -128; i.e., subtracts 128
408 features_buffer[i] = static_cast<int8_t>(clamp<int32_t>(value, INT8_MIN, INT8_MAX));
409 }
410
411 return processed_samples;
412}
413
415#ifdef USE_MICRO_WAKE_WORD_VAD
416 DetectionEvent vad_state = this->vad_model_->determine_detected();
417
418 this->vad_state_ = vad_state.detected; // atomic write, so thread safe
419#endif
420
421 for (auto &model : this->wake_word_models_) {
422 if (model->get_unprocessed_probability_status()) {
423 // Only detect wake words if there is a new probability since the last check
424 DetectionEvent wake_word_state = model->determine_detected();
425 if (wake_word_state.detected) {
426#ifdef USE_MICRO_WAKE_WORD_VAD
427 if (vad_state.detected) {
428#endif
429 xQueueSend(this->detection_queue_, &wake_word_state, portMAX_DELAY);
430
431 // Wake main loop immediately to process wake word detection
432#if defined(USE_SOCKET_SELECT_SUPPORT) && defined(USE_WAKE_LOOP_THREADSAFE)
434#endif
435
436 model->reset_probabilities();
437#ifdef USE_MICRO_WAKE_WORD_VAD
438 } else {
439 wake_word_state.blocked_by_vad = true;
440 xQueueSend(this->detection_queue_, &wake_word_state, portMAX_DELAY);
441 }
442#endif
443 }
444 }
445 }
446}
447
449 for (auto &model : this->wake_word_models_) {
450 model->unload_model();
451 }
452#ifdef USE_MICRO_WAKE_WORD_VAD
453 this->vad_model_->unload_model();
454#endif
455}
456
457bool MicroWakeWord::update_model_probabilities_(const int8_t audio_features[PREPROCESSOR_FEATURE_SIZE]) {
458 bool success = true;
459
460 for (auto &model : this->wake_word_models_) {
461 // Perform inference
462 success = success & model->perform_streaming_inference(audio_features);
463 }
464#ifdef USE_MICRO_WAKE_WORD_VAD
465 success = success & this->vad_model_->perform_streaming_inference(audio_features);
466#endif
467
468 return success;
469}
470
471} // namespace micro_wake_word
472} // namespace esphome
473
474#endif // USE_ESP_IDF
void wake_loop_threadsafe()
Wake the main event loop from a FreeRTOS task Thread-safe, can be called from task context to immedia...
virtual void mark_failed()
Mark this component as failed.
bool is_failed() const
void status_momentary_error(const std::string &name, uint32_t length=5000)
bool is_ready() const
bool status_has_error() const
static std::unique_ptr< RingBuffer > create(size_t len)
void trigger(const Ts &...x)
Inform the parent automation that the event has triggered.
Definition automation.h:169
static std::unique_ptr< AudioSourceTransferBuffer > create(size_t buffer_size)
Creates a new source transfer buffer.
size_t ms_to_bytes(uint32_t ms) const
Converts duration to bytes.
Definition audio.h:73
uint32_t get_sample_rate() const
Definition audio.h:30
void resume_task_()
Resumes the inference task.
microphone::MicrophoneSource * microphone_source_
void process_probabilities_()
Processes any new probabilities for each model.
std::weak_ptr< RingBuffer > ring_buffer_
Trigger< std::string > * wake_word_detected_trigger_
std::vector< WakeWordModel * > wake_word_models_
void suspend_task_()
Suspends the inference task.
void add_wake_word_model(WakeWordModel *model)
bool update_model_probabilities_(const int8_t audio_features[PREPROCESSOR_FEATURE_SIZE])
Runs an inference with each model using the new spectrogram features.
size_t generate_features_(int16_t *audio_buffer, size_t samples_available, int8_t features_buffer[PREPROCESSOR_FEATURE_SIZE])
Generates spectrogram features from an input buffer of audio samples.
std::unique_ptr< VADModel > vad_model_
void add_vad_model(const uint8_t *model_start, uint8_t probability_cutoff, size_t sliding_window_size, size_t tensor_arena_size)
void unload_models_()
Deletes each model's TFLite interpreters and frees tensor arena memory.
std::vector< WakeWordModel * > get_wake_words()
void add_data_callback(std::function< void(const std::vector< uint8_t > &)> &&data_callback)
audio::AudioStreamInfo get_audio_stream_info()
Gets the AudioStreamInfo of the data after processing.
void add_on_state_callback(std::function< void(OTAState, float, uint8_t, OTAComponent *)> &&callback)
bool state
Definition fan.h:0
__int64 ssize_t
Definition httplib.h:178
OTAGlobalCallback * get_global_ota_callback()
const float AFTER_CONNECTION
For components that should be initialized after a data connection (API/MQTT) is connected.
Definition component.cpp:67
Providing packet encoding functions for exchanging data with a remote host.
Definition a01nyub.cpp:7
void IRAM_ATTR HOT delay(uint32_t ms)
Definition core.cpp:31
Application App
Global storage of Application pointer - only one Application can exist.