ESPHome 2026.6.0-dev
Loading...
Searching...
No Matches
streaming_model.h
Go to the documentation of this file.
1#pragma once
2
3#ifdef USE_ESP32
4
6
8
9#include <tensorflow/lite/core/c/common.h>
10#include <tensorflow/lite/micro/micro_interpreter.h>
11#include <tensorflow/lite/micro/micro_mutable_op_resolver.h>
12
14
// Seeds ignore_windows_, which is initialized to the negation of this value.
// NOTE(review): presumably the number of feature slices that must be processed before a detection is considered -
// confirm against the implementation.
15static const uint8_t MIN_SLICES_BEFORE_DETECTION = 100;
// NOTE(review): presumably the size in bytes of the variable arena (var_arena_) - confirm against load_model_().
16static const uint32_t STREAMING_MODEL_VARIABLE_ARENA_SIZE = 1024;
17
19 std::string *wake_word;
21 bool partially_detection; // Set if the most recent probability exceeds the threshold, but the sliding window average
22 // hasn't yet
25 bool blocked_by_vad = false;
26};
27
29 public:
30 virtual void log_model_config() = 0;
32
33 // Performs inference on the given features.
34 // - If the model is enabled but not loaded, it will load it
35 // - If the model is disabled but loaded, it will unload it
36 // Returns true if successful or false if there is an error
37 bool perform_streaming_inference(const int8_t features[PREPROCESSOR_FEATURE_SIZE]);
38
41
43 void unload_model();
44
// Enables the model by setting enabled_ to true. Per this page's member index, the next
// perform_streaming_inference call will then load it.
46 virtual void enable() { this->enabled_ = true; }
47
// Disables the model by setting enabled_ to false. Per this page's member index, the next
// perform_streaming_inference call will then unload it.
49 virtual void disable() { this->enabled_ = false; }
50
// Returns true if the model is enabled (the current value of enabled_).
52 bool is_enabled() const { return this->enabled_; }
53
55
56 // Quantized probability cutoffs mapping 0.0 - 1.0 to 0 - 255
// Returns the stored probability cutoff (probability_cutoff_) as a uint8_t.
58 uint8_t get_probability_cutoff() const { return this->probability_cutoff_; }
// Replaces the stored probability cutoff with the given uint8_t value; no validation is performed here.
59 void set_probability_cutoff(uint8_t probability_cutoff) { this->probability_cutoff_ = probability_cutoff; }
60
61 protected:
64 bool load_model_();
68 size_t probe_arena_size_();
70 bool register_streaming_ops_(tflite::MicroMutableOpResolver<20> &op_resolver);
71
72 tflite::MicroMutableOpResolver<20> streaming_op_resolver_;
73
74 bool loaded_{false};
75 bool enabled_{true};
79 int16_t ignore_windows_{-MIN_SLICES_BEFORE_DETECTION};
80
84
85 size_t last_n_index_{0};
87 std::vector<uint8_t> recent_streaming_probabilities_;
88
89 const uint8_t *model_start_;
90 uint8_t *tensor_arena_{nullptr};
91 uint8_t *var_arena_{nullptr};
92 std::unique_ptr<tflite::MicroInterpreter> interpreter_;
93 tflite::MicroResourceVariables *mrv_{nullptr};
94 tflite::MicroAllocator *ma_{nullptr};
95};
96
// Wake word model: a StreamingModel subclass that additionally stores an id, the wake word phrase, and the list of
// languages the model was trained on.
// NOTE(review): this is a documentation-page listing; every line keeps its original listing number and several
// numbered lines (e.g. 99-108, 115-118, 139-141) are absent, so some members are not visible in this view.
97class WakeWordModel final : public StreamingModel {
98 public:
 // Constructs a wake word model object.
 // NOTE(review): parameter semantics are not shown in this listing. Going by name and type only,
 // default_probability_cutoff is presumably on the quantized 0 - 255 scale - confirm against the implementation.
109 WakeWordModel(const std::string &id, const uint8_t *model_start, uint8_t default_probability_cutoff,
110 size_t sliding_window_average_size, const std::string &wake_word, size_t tensor_arena_size,
111 bool default_enabled, bool internal_only);
112
 // Overrides the pure virtual StreamingModel::log_model_config().
113 void log_model_config() override;
114
119
 // Returns a const reference to the stored id string.
120 const std::string &get_id() const { return this->id_; }
 // Returns a const reference to the stored wake word string.
121 const std::string &get_wake_word() const { return this->wake_word_; }
122
 // Appends the given language to trained_languages_.
123 void add_trained_language(const std::string &language) { this->trained_languages_.push_back(language); }
 // Returns a const reference to the list of trained languages.
124 const std::vector<std::string> &get_trained_languages() const { return this->trained_languages_; }
125
 // Enable the model and save to flash. The next perform_streaming_inference call will load it.
127 void enable() override;
128
 // Disable the model and save to flash. The next perform_streaming_inference call will unload it.
130 void disable() override;
131
 // Returns the internal_only_ flag.
 // NOTE(review): unlike the other getters in this class, this accessor is not const-qualified, so it cannot be
 // called through a const reference - consider adding const.
132 bool get_internal_only() { return this->internal_only_; }
133
134 protected:
135 std::string id_;
136 std::string wake_word_;
137 std::vector<std::string> trained_languages_;
138
140
142};
143
// Voice activity detection (VAD) model: a StreamingModel subclass.
// NOTE(review): this is a documentation-page listing; numbered lines 150-154 are absent, so the
// determine_detected() override described in the member index is presumably declared there - confirm.
144class VADModel final : public StreamingModel {
145 public:
 // Constructs a VAD model from the model data pointer, a default probability cutoff, a sliding window size,
 // and a tensor arena size. NOTE(review): units and valid ranges are not shown in this listing - confirm.
146 VADModel(const uint8_t *model_start, uint8_t default_probability_cutoff, size_t sliding_window_size,
147 size_t tensor_arena_size);
148
 // Overrides the pure virtual StreamingModel::log_model_config().
149 void log_model_config() override;
150
155};
156
157} // namespace esphome::micro_wake_word
158
159#endif
virtual void disable()
Disable the model. The next perform_streaming_inference call will unload it.
virtual DetectionEvent determine_detected()=0
bool load_model_()
Allocates tensor and variable arenas and sets up the model interpreter.
virtual void enable()
Enable the model. The next perform_streaming_inference call will load it.
std::unique_ptr< tflite::MicroInterpreter > interpreter_
tflite::MicroMutableOpResolver< 20 > streaming_op_resolver_
bool register_streaming_ops_(tflite::MicroMutableOpResolver< 20 > &op_resolver)
Returns true if successfully registered the streaming model's TensorFlow operations.
void reset_probabilities()
Sets all recent_streaming_probabilities to 0 and resets the ignore window count.
std::vector< uint8_t > recent_streaming_probabilities_
size_t probe_arena_size_()
Probes the actual required tensor arena size by trial allocation.
tflite::MicroResourceVariables * mrv_
bool perform_streaming_inference(const int8_t features[PREPROCESSOR_FEATURE_SIZE])
void unload_model()
Destroys the TFLite interpreter and frees the tensor and variable arenas' memory.
void set_probability_cutoff(uint8_t probability_cutoff)
bool is_enabled() const
Return true if the model is enabled.
DetectionEvent determine_detected() override
Checks for voice activity by comparing the max probability in the sliding window with the probability...
VADModel(const uint8_t *model_start, uint8_t default_probability_cutoff, size_t sliding_window_size, size_t tensor_arena_size)
void enable() override
Enable the model and save to flash. The next perform_streaming_inference call will load it.
const std::string & get_wake_word() const
DetectionEvent determine_detected() override
Checks for the wake word by comparing the mean probability in the sliding window with the probability...
const std::vector< std::string > & get_trained_languages() const
WakeWordModel(const std::string &id, const uint8_t *model_start, uint8_t default_probability_cutoff, size_t sliding_window_average_size, const std::string &wake_word, size_t tensor_arena_size, bool default_enabled, bool internal_only)
Constructs a wake word model object.
void add_trained_language(const std::string &language)
void disable() override
Disable the model and save to flash. The next perform_streaming_inference call will unload it.
std::vector< std::string > trained_languages_
static void uint32_t