ESPHome 2025.9.0-dev
Loading...
Searching...
No Matches
micro_wake_word.cpp
Go to the documentation of this file.
1#include "micro_wake_word.h"
2
3#ifdef USE_ESP_IDF
4
5#include "esphome/core/hal.h"
7#include "esphome/core/log.h"
8
10
11#ifdef USE_OTA
13#endif
14
15namespace esphome {
16namespace micro_wake_word {
17
18static const char *const TAG = "micro_wake_word";
19
20static const ssize_t DETECTION_QUEUE_LENGTH = 5;
21
22static const size_t DATA_TIMEOUT_MS = 50;
23
24static const uint32_t RING_BUFFER_DURATION_MS = 120;
25
26static const uint32_t INFERENCE_TASK_STACK_SIZE = 3072;
27static const UBaseType_t INFERENCE_TASK_PRIORITY = 3;
28
/// Bit flags stored in the component's FreeRTOS event group. The inference task sets the TASK_*, ERROR_* and
/// WARNING_* bits; COMMAND_STOP is set for the task to read.
enum EventGroupBits : uint32_t {
  COMMAND_STOP = (1 << 0),  // Signals the inference task should stop

  // Lifecycle notifications raised by the inference task as it progresses
  TASK_STARTING = (1 << 3),
  TASK_RUNNING = (1 << 4),
  TASK_STOPPING = (1 << 5),
  TASK_STOPPED = (1 << 6),

  ERROR_MEMORY = (1 << 9),      // A buffer allocation failed
  ERROR_INFERENCE = (1 << 10),  // Running a model inference failed

  // Raised when incoming audio did not fit in the ring buffer and the buffer was reset.
  // NOTE(review): this member was missing from the listing even though other code in this file references it;
  // the bit position is an otherwise unused bit -- confirm it against the upstream source.
  WARNING_FULL_RING_BUFFER = (1 << 13),

  // Any error flag; the inference task tests this mask before each allocation step and before entering its loop
  ERROR_BITS = ERROR_MEMORY | ERROR_INFERENCE,

  // Mask used to clear every flag at once. It covers bits 0-19, which includes all flags defined above
  // (an event group offers 24 usable bits in total).
  ALL_BITS = 0xfffff,
};
45
47
48static const LogString *micro_wake_word_state_to_string(State state) {
49 switch (state) {
50 case State::STARTING:
51 return LOG_STR("STARTING");
53 return LOG_STR("DETECTING_WAKE_WORD");
54 case State::STOPPING:
55 return LOG_STR("STOPPING");
56 case State::STOPPED:
57 return LOG_STR("STOPPED");
58 default:
59 return LOG_STR("UNKNOWN");
60 }
61}
62
64 ESP_LOGCONFIG(TAG, "microWakeWord:");
65 ESP_LOGCONFIG(TAG, " models:");
66 for (auto &model : this->wake_word_models_) {
67 model->log_model_config();
68 }
69#ifdef USE_MICRO_WAKE_WORD_VAD
70 this->vad_model_->log_model_config();
71#endif
72}
73
75 this->frontend_config_.window.size_ms = FEATURE_DURATION_MS;
76 this->frontend_config_.window.step_size_ms = this->features_step_size_;
77 this->frontend_config_.filterbank.num_channels = PREPROCESSOR_FEATURE_SIZE;
78 this->frontend_config_.filterbank.lower_band_limit = FILTERBANK_LOWER_BAND_LIMIT;
79 this->frontend_config_.filterbank.upper_band_limit = FILTERBANK_UPPER_BAND_LIMIT;
80 this->frontend_config_.noise_reduction.smoothing_bits = NOISE_REDUCTION_SMOOTHING_BITS;
81 this->frontend_config_.noise_reduction.even_smoothing = NOISE_REDUCTION_EVEN_SMOOTHING;
82 this->frontend_config_.noise_reduction.odd_smoothing = NOISE_REDUCTION_ODD_SMOOTHING;
83 this->frontend_config_.noise_reduction.min_signal_remaining = NOISE_REDUCTION_MIN_SIGNAL_REMAINING;
84 this->frontend_config_.pcan_gain_control.enable_pcan = PCAN_GAIN_CONTROL_ENABLE_PCAN;
85 this->frontend_config_.pcan_gain_control.strength = PCAN_GAIN_CONTROL_STRENGTH;
86 this->frontend_config_.pcan_gain_control.offset = PCAN_GAIN_CONTROL_OFFSET;
87 this->frontend_config_.pcan_gain_control.gain_bits = PCAN_GAIN_CONTROL_GAIN_BITS;
88 this->frontend_config_.log_scale.enable_log = LOG_SCALE_ENABLE_LOG;
89 this->frontend_config_.log_scale.scale_shift = LOG_SCALE_SCALE_SHIFT;
90
91 this->event_group_ = xEventGroupCreate();
92 if (this->event_group_ == nullptr) {
93 ESP_LOGE(TAG, "Failed to create event group");
94 this->mark_failed();
95 return;
96 }
97
98 this->detection_queue_ = xQueueCreate(DETECTION_QUEUE_LENGTH, sizeof(DetectionEvent));
99 if (this->detection_queue_ == nullptr) {
100 ESP_LOGE(TAG, "Failed to create detection event queue");
101 this->mark_failed();
102 return;
103 }
104
105 this->microphone_source_->add_data_callback([this](const std::vector<uint8_t> &data) {
106 if (this->state_ == State::STOPPED) {
107 return;
108 }
109 std::shared_ptr<RingBuffer> temp_ring_buffer = this->ring_buffer_.lock();
110 if (this->ring_buffer_.use_count() > 1) {
111 size_t bytes_free = temp_ring_buffer->free();
112
113 if (bytes_free < data.size()) {
115 temp_ring_buffer->reset();
116 }
117 temp_ring_buffer->write((void *) data.data(), data.size());
118 }
119 });
120
121#ifdef USE_OTA
123 [this](ota::OTAState state, float progress, uint8_t error, ota::OTAComponent *comp) {
124 if (state == ota::OTA_STARTED) {
125 this->suspend_task_();
126 } else if (state == ota::OTA_ERROR) {
127 this->resume_task_();
128 }
129 });
130#endif
131}
132
134 MicroWakeWord *this_mww = (MicroWakeWord *) params;
135
136 xEventGroupSetBits(this_mww->event_group_, EventGroupBits::TASK_STARTING);
137
138 { // Ensures any C++ objects fall out of scope to deallocate before deleting the task
139
140 const size_t new_bytes_to_process =
142 std::unique_ptr<audio::AudioSourceTransferBuffer> audio_buffer;
143 int8_t features_buffer[PREPROCESSOR_FEATURE_SIZE];
144
145 if (!(xEventGroupGetBits(this_mww->event_group_) & ERROR_BITS)) {
146 // Allocate audio transfer buffer
147 audio_buffer = audio::AudioSourceTransferBuffer::create(new_bytes_to_process);
148
149 if (audio_buffer == nullptr) {
150 xEventGroupSetBits(this_mww->event_group_, EventGroupBits::ERROR_MEMORY);
151 }
152 }
153
154 if (!(xEventGroupGetBits(this_mww->event_group_) & ERROR_BITS)) {
155 // Allocate ring buffer
156 std::shared_ptr<RingBuffer> temp_ring_buffer = RingBuffer::create(
157 this_mww->microphone_source_->get_audio_stream_info().ms_to_bytes(RING_BUFFER_DURATION_MS));
158 if (temp_ring_buffer.use_count() == 0) {
159 xEventGroupSetBits(this_mww->event_group_, EventGroupBits::ERROR_MEMORY);
160 }
161 audio_buffer->set_source(temp_ring_buffer);
162 this_mww->ring_buffer_ = temp_ring_buffer;
163 }
164
165 if (!(xEventGroupGetBits(this_mww->event_group_) & ERROR_BITS)) {
166 this_mww->microphone_source_->start();
167 xEventGroupSetBits(this_mww->event_group_, EventGroupBits::TASK_RUNNING);
168
169 while (!(xEventGroupGetBits(this_mww->event_group_) & COMMAND_STOP)) {
170 audio_buffer->transfer_data_from_source(pdMS_TO_TICKS(DATA_TIMEOUT_MS));
171
172 if (audio_buffer->available() < new_bytes_to_process) {
173 // Insufficient data to generate new spectrogram features, read more next iteration
174 continue;
175 }
176
177 // Generate new spectrogram features
178 uint32_t processed_samples = this_mww->generate_features_(
179 (int16_t *) audio_buffer->get_buffer_start(), audio_buffer->available() / sizeof(int16_t), features_buffer);
180 audio_buffer->decrease_buffer_length(processed_samples * sizeof(int16_t));
181
182 // Run inference using the new spectrogram features
183 if (!this_mww->update_model_probabilities_(features_buffer)) {
184 xEventGroupSetBits(this_mww->event_group_, EventGroupBits::ERROR_INFERENCE);
185 break;
186 }
187
188 // Process each model's probabilities and possibly send a Detection Event to the queue
189 this_mww->process_probabilities_();
190 }
191 }
192 }
193
194 xEventGroupSetBits(this_mww->event_group_, EventGroupBits::TASK_STOPPING);
195
196 this_mww->unload_models_();
197 this_mww->microphone_source_->stop();
198 FrontendFreeStateContents(&this_mww->frontend_state_);
199
200 xEventGroupSetBits(this_mww->event_group_, EventGroupBits::TASK_STOPPED);
201 while (true) {
202 // Continuously delay until the main loop deletes the task
203 delay(10);
204 }
205}
206
207std::vector<WakeWordModel *> MicroWakeWord::get_wake_words() {
208 std::vector<WakeWordModel *> external_wake_word_models;
209 for (auto *model : this->wake_word_models_) {
210 if (!model->get_internal_only()) {
211 external_wake_word_models.push_back(model);
212 }
213 }
214 return external_wake_word_models;
215}
216
218
219#ifdef USE_MICRO_WAKE_WORD_VAD
220void MicroWakeWord::add_vad_model(const uint8_t *model_start, uint8_t probability_cutoff, size_t sliding_window_size,
221 size_t tensor_arena_size) {
222 this->vad_model_ = make_unique<VADModel>(model_start, probability_cutoff, sliding_window_size, tensor_arena_size);
223}
224#endif
225
227 if (this->inference_task_handle_ != nullptr) {
228 vTaskSuspend(this->inference_task_handle_);
229 }
230}
231
233 if (this->inference_task_handle_ != nullptr) {
234 vTaskResume(this->inference_task_handle_);
235 }
236}
237
239 uint32_t event_group_bits = xEventGroupGetBits(this->event_group_);
240
241 if (event_group_bits & EventGroupBits::ERROR_MEMORY) {
242 xEventGroupClearBits(this->event_group_, EventGroupBits::ERROR_MEMORY);
243 ESP_LOGE(TAG, "Encountered an error allocating buffers");
244 }
245
246 if (event_group_bits & EventGroupBits::ERROR_INFERENCE) {
247 xEventGroupClearBits(this->event_group_, EventGroupBits::ERROR_INFERENCE);
248 ESP_LOGE(TAG, "Encountered an error while performing an inference");
249 }
250
251 if (event_group_bits & EventGroupBits::WARNING_FULL_RING_BUFFER) {
252 xEventGroupClearBits(this->event_group_, EventGroupBits::WARNING_FULL_RING_BUFFER);
253 ESP_LOGW(TAG, "Not enough free bytes in ring buffer to store incoming audio data. Resetting the ring buffer. Wake "
254 "word detection accuracy will temporarily be reduced.");
255 }
256
257 if (event_group_bits & EventGroupBits::TASK_STARTING) {
258 ESP_LOGD(TAG, "Inference task has started, attempting to allocate memory for buffers");
259 xEventGroupClearBits(this->event_group_, EventGroupBits::TASK_STARTING);
260 }
261
262 if (event_group_bits & EventGroupBits::TASK_RUNNING) {
263 ESP_LOGD(TAG, "Inference task is running");
264
265 xEventGroupClearBits(this->event_group_, EventGroupBits::TASK_RUNNING);
267 }
268
269 if (event_group_bits & EventGroupBits::TASK_STOPPING) {
270 ESP_LOGD(TAG, "Inference task is stopping, deallocating buffers");
271 xEventGroupClearBits(this->event_group_, EventGroupBits::TASK_STOPPING);
272 }
273
274 if ((event_group_bits & EventGroupBits::TASK_STOPPED)) {
275 ESP_LOGD(TAG, "Inference task is finished, freeing task resources");
276 vTaskDelete(this->inference_task_handle_);
277 this->inference_task_handle_ = nullptr;
278 xEventGroupClearBits(this->event_group_, ALL_BITS);
279 xQueueReset(this->detection_queue_);
281 }
282
283 if ((this->pending_start_) && (this->state_ == State::STOPPED)) {
285 this->pending_start_ = false;
286 }
287
288 if ((this->pending_stop_) && (this->state_ == State::DETECTING_WAKE_WORD)) {
290 this->pending_stop_ = false;
291 }
292
293 switch (this->state_) {
294 case State::STARTING:
295 if ((this->inference_task_handle_ == nullptr) && !this->status_has_error()) {
296 // Setup preprocessor feature generator. If done in the task, it would lock the task to its initial core, as it
297 // uses floating point operations.
298 if (!FrontendPopulateState(&this->frontend_config_, &this->frontend_state_,
301 "Failed to allocate buffers for spectrogram feature processor, attempting again in 1 second", 1000);
302 return;
303 }
304
305 xTaskCreate(MicroWakeWord::inference_task, "mww", INFERENCE_TASK_STACK_SIZE, (void *) this,
306 INFERENCE_TASK_PRIORITY, &this->inference_task_handle_);
307
308 if (this->inference_task_handle_ == nullptr) {
309 FrontendFreeStateContents(&this->frontend_state_); // Deallocate frontend state
310 this->status_momentary_error("Task failed to start, attempting again in 1 second", 1000);
311 }
312 }
313 break;
315 DetectionEvent detection_event;
316 while (xQueueReceive(this->detection_queue_, &detection_event, 0)) {
317 if (detection_event.blocked_by_vad) {
318 ESP_LOGD(TAG, "Wake word model predicts '%s', but VAD model doesn't.", detection_event.wake_word->c_str());
319 } else {
320 constexpr float uint8_to_float_divisor =
321 255.0f; // Converting a quantized uint8 probability to floating point
322 ESP_LOGD(TAG, "Detected '%s' with sliding average probability is %.2f and max probability is %.2f",
323 detection_event.wake_word->c_str(), (detection_event.average_probability / uint8_to_float_divisor),
324 (detection_event.max_probability / uint8_to_float_divisor));
325 this->wake_word_detected_trigger_->trigger(*detection_event.wake_word);
326 if (this->stop_after_detection_) {
327 this->stop();
328 }
329 }
330 }
331 break;
332 }
333 case State::STOPPING:
334 xEventGroupSetBits(this->event_group_, EventGroupBits::COMMAND_STOP);
335 break;
336 case State::STOPPED:
337 break;
338 }
339}
340
342 if (!this->is_ready()) {
343 ESP_LOGW(TAG, "Wake word detection can't start as the component hasn't been setup yet");
344 return;
345 }
346
347 if (this->is_failed()) {
348 ESP_LOGW(TAG, "Wake word component is marked as failed. Please check setup logs");
349 return;
350 }
351
352 if (this->is_running()) {
353 ESP_LOGW(TAG, "Wake word detection is already running");
354 return;
355 }
356
357 ESP_LOGD(TAG, "Starting wake word detection");
358
359 this->pending_start_ = true;
360 this->pending_stop_ = false;
361}
362
364 if (this->state_ == STOPPED)
365 return;
366
367 ESP_LOGD(TAG, "Stopping wake word detection");
368
369 this->pending_start_ = false;
370 this->pending_stop_ = true;
371}
372
374 if (this->state_ != state) {
375 ESP_LOGD(TAG, "State changed from %s to %s", LOG_STR_ARG(micro_wake_word_state_to_string(this->state_)),
376 LOG_STR_ARG(micro_wake_word_state_to_string(state)));
377 this->state_ = state;
378 }
379}
380
381size_t MicroWakeWord::generate_features_(int16_t *audio_buffer, size_t samples_available,
382 int8_t features_buffer[PREPROCESSOR_FEATURE_SIZE]) {
383 size_t processed_samples = 0;
384 struct FrontendOutput frontend_output =
385 FrontendProcessSamples(&this->frontend_state_, audio_buffer, samples_available, &processed_samples);
386
387 for (size_t i = 0; i < frontend_output.size; ++i) {
388 // These scaling values are set to match the TFLite audio frontend int8 output.
389 // The feature pipeline outputs 16-bit signed integers in roughly a 0 to 670
390 // range. In training, these are then arbitrarily divided by 25.6 to get
391 // float values in the rough range of 0.0 to 26.0. This scaling is performed
392 // for historical reasons, to match up with the output of other feature
393 // generators.
394 // The process is then further complicated when we quantize the model. This
395 // means we have to scale the 0.0 to 26.0 real values to the -128 (INT8_MIN)
396 // to 127 (INT8_MAX) signed integer numbers.
397 // All this means that to get matching values from our integer feature
398 // output into the tensor input, we have to perform:
399 // input = (((feature / 25.6) / 26.0) * 256) - 128
400 // To simplify this and perform it in 32-bit integer math, we rearrange to:
401 // input = (feature * 256) / (25.6 * 26.0) - 128
402 constexpr int32_t value_scale = 256;
403 constexpr int32_t value_div = 666; // 666 = 25.6 * 26.0 after rounding
404 int32_t value = ((frontend_output.values[i] * value_scale) + (value_div / 2)) / value_div;
405
406 value += INT8_MIN; // Adds a -128; i.e., subtracts 128
407 features_buffer[i] = static_cast<int8_t>(clamp<int32_t>(value, INT8_MIN, INT8_MAX));
408 }
409
410 return processed_samples;
411}
412
414#ifdef USE_MICRO_WAKE_WORD_VAD
415 DetectionEvent vad_state = this->vad_model_->determine_detected();
416
417 this->vad_state_ = vad_state.detected; // atomic write, so thread safe
418#endif
419
420 for (auto &model : this->wake_word_models_) {
421 if (model->get_unprocessed_probability_status()) {
422 // Only detect wake words if there is a new probability since the last check
423 DetectionEvent wake_word_state = model->determine_detected();
424 if (wake_word_state.detected) {
425#ifdef USE_MICRO_WAKE_WORD_VAD
426 if (vad_state.detected) {
427#endif
428 xQueueSend(this->detection_queue_, &wake_word_state, portMAX_DELAY);
429 model->reset_probabilities();
430#ifdef USE_MICRO_WAKE_WORD_VAD
431 } else {
432 wake_word_state.blocked_by_vad = true;
433 xQueueSend(this->detection_queue_, &wake_word_state, portMAX_DELAY);
434 }
435#endif
436 }
437 }
438 }
439}
440
442 for (auto &model : this->wake_word_models_) {
443 model->unload_model();
444 }
445#ifdef USE_MICRO_WAKE_WORD_VAD
446 this->vad_model_->unload_model();
447#endif
448}
449
450bool MicroWakeWord::update_model_probabilities_(const int8_t audio_features[PREPROCESSOR_FEATURE_SIZE]) {
451 bool success = true;
452
453 for (auto &model : this->wake_word_models_) {
454 // Perform inference
455 success = success & model->perform_streaming_inference(audio_features);
456 }
457#ifdef USE_MICRO_WAKE_WORD_VAD
458 success = success & this->vad_model_->perform_streaming_inference(audio_features);
459#endif
460
461 return success;
462}
463
464} // namespace micro_wake_word
465} // namespace esphome
466
467#endif // USE_ESP_IDF
virtual void mark_failed()
Mark this component as failed.
bool is_failed() const
void status_momentary_error(const std::string &name, uint32_t length=5000)
bool is_ready() const
bool status_has_error() const
static std::unique_ptr< RingBuffer > create(size_t len)
void trigger(Ts... x)
Inform the parent automation that the event has triggered.
Definition automation.h:145
static std::unique_ptr< AudioSourceTransferBuffer > create(size_t buffer_size)
Creates a new source transfer buffer.
size_t ms_to_bytes(uint32_t ms) const
Converts duration to bytes.
Definition audio.h:73
uint32_t get_sample_rate() const
Definition audio.h:30
void resume_task_()
Resumes the inference task.
microphone::MicrophoneSource * microphone_source_
void process_probabilities_()
Processes any new probabilities for each model.
std::weak_ptr< RingBuffer > ring_buffer_
Trigger< std::string > * wake_word_detected_trigger_
std::vector< WakeWordModel * > wake_word_models_
void suspend_task_()
Suspends the inference task.
void add_wake_word_model(WakeWordModel *model)
bool update_model_probabilities_(const int8_t audio_features[PREPROCESSOR_FEATURE_SIZE])
Runs an inference with each model using the new spectrogram features.
size_t generate_features_(int16_t *audio_buffer, size_t samples_available, int8_t features_buffer[PREPROCESSOR_FEATURE_SIZE])
Generates spectrogram features from an input buffer of audio samples.
std::unique_ptr< VADModel > vad_model_
void add_vad_model(const uint8_t *model_start, uint8_t probability_cutoff, size_t sliding_window_size, size_t tensor_arena_size)
void unload_models_()
Deletes each model's TFLite interpreters and frees tensor arena memory.
std::vector< WakeWordModel * > get_wake_words()
void add_data_callback(std::function< void(const std::vector< uint8_t > &)> &&data_callback)
audio::AudioStreamInfo get_audio_stream_info()
Gets the AudioStreamInfo of the data after processing.
void add_on_state_callback(std::function< void(OTAState, float, uint8_t, OTAComponent *)> &&callback)
bool state
Definition fan.h:0
__int64 ssize_t
Definition httplib.h:175
OTAGlobalCallback * get_global_ota_callback()
const float AFTER_CONNECTION
For components that should be initialized after a data connection (API/MQTT) is connected.
Definition component.cpp:58
Providing packet encoding functions for exchanging data with a remote host.
Definition a01nyub.cpp:7
void IRAM_ATTR HOT delay(uint32_t ms)
Definition core.cpp:29