From 1090fdb397a341c15e635973750afedb6ccc8e44 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 29 May 2022 11:47:39 +0300 Subject: [PATCH] ecc : reduce memory allocations in reed-solomon lib --- include/ggwave/ggwave.h | 39 +++++++++---------- src/ggwave.cpp | 84 +++++++++++++++++++++++------------------ src/reed-solomon/rs.hpp | 31 ++++++++++----- tests/test-ggwave.cpp | 4 +- 4 files changed, 91 insertions(+), 67 deletions(-) diff --git a/include/ggwave/ggwave.h b/include/ggwave/ggwave.h index daf08ef..fa1c0d9 100644 --- a/include/ggwave/ggwave.h +++ b/include/ggwave/ggwave.h @@ -298,26 +298,25 @@ extern "C" { #include #include #include -#include #include class GGWave { public: - static constexpr auto kSampleRateMin = 1000.0f; - static constexpr auto kSampleRateMax = 96000.0f; - static constexpr auto kDefaultSampleRate = 48000.0f; - static constexpr auto kDefaultSamplesPerFrame = 1024; - static constexpr auto kDefaultVolume = 10; + static constexpr auto kSampleRateMin = 1000.0f; + static constexpr auto kSampleRateMax = 96000.0f; + static constexpr auto kDefaultSampleRate = 48000.0f; + static constexpr auto kDefaultSamplesPerFrame = 1024; + static constexpr auto kDefaultVolume = 10; static constexpr auto kDefaultSoundMarkerThreshold = 3.0f; - static constexpr auto kDefaultMarkerFrames = 16; - static constexpr auto kDefaultEncodedDataOffset = 3; - static constexpr auto kMaxSamplesPerFrame = 1024; - static constexpr auto kMaxDataBits = 256; - static constexpr auto kMaxDataSize = 256; - static constexpr auto kMaxLengthVarible = 140; - static constexpr auto kMaxLengthFixed = 16; - static constexpr auto kMaxSpectrumHistory = 4; - static constexpr auto kMaxRecordedFrames = 2048; + static constexpr auto kDefaultMarkerFrames = 16; + static constexpr auto kDefaultEncodedDataOffset = 3; + static constexpr auto kMaxSamplesPerFrame = 1024; + static constexpr auto kMaxDataBits = 256; + static constexpr auto kMaxDataSize = 256; + static constexpr auto kMaxLengthVariable = 140; + static constexpr auto kMaxLengthFixed = 16; + static constexpr auto kMaxSpectrumHistory = 4; + static constexpr auto kMaxRecordedFrames = 2048; using Parameters = ggwave_Parameters; using SampleFormat = ggwave_SampleFormat; @@ -361,14 +360,14 @@ public: double duration_ms; }; - using Tones = std::vector; + using Tones = std::vector; using WaveformTones = std::vector; using AmplitudeData = std::vector; using AmplitudeDataI16 = std::vector; using SpectrumData = std::vector; using RecordedData = std::vector; - using TxRxData = std::vector; + using TxRxData = std::vector; using CBWaveformOut = std::function; using CBWaveformInp = std::function; @@ -394,8 +393,8 @@ public: // // returns false upon invalid parameters or failure to initialize // - bool init(const std::string & text, const int volume = kDefaultVolume); - bool init(const std::string & text, const TxProtocol & txProtocol, const int volume = kDefaultVolume); + bool init(const char * text, const int volume = kDefaultVolume); + bool init(const char * text, const TxProtocol & txProtocol, const int volume = kDefaultVolume); bool init(int dataSize, const char * dataBuffer, const int volume = kDefaultVolume); bool init(int dataSize, const char * dataBuffer, const TxProtocol & txProtocol, const int volume = kDefaultVolume); @@ -576,6 +575,8 @@ private: // common TxRxData m_dataEncoded; + TxRxData m_workRSLength; // Reed-Solomon work buffers + TxRxData m_workRSData; // Impl struct Rx; diff --git a/src/ggwave.cpp b/src/ggwave.cpp index 1786140..714690f 100644 --- a/src/ggwave.cpp +++ b/src/ggwave.cpp @@ -4,8 +4,6 @@ #include #include -#include -#include #include //#include @@ -493,10 +491,10 @@ GGWave::GGWave(const Parameters & parameters) : m_rx->fftInp.resize(m_samplesPerFrame); m_rx->fftOut.resize(2*m_samplesPerFrame); - m_rx->sampleSpectrum.resize (m_samplesPerFrame); - m_rx->sampleAmplitude.resize (m_needResampling ? m_samplesPerFrame + 128 : m_samplesPerFrame); // small extra space because sometimes resampling needs a few more samples + m_rx->sampleSpectrum.resize(m_samplesPerFrame); + m_rx->sampleAmplitude.resize(m_needResampling ? m_samplesPerFrame + 128 : m_samplesPerFrame); // small extra space because sometimes resampling needs a few more samples m_rx->sampleAmplitudeResampled.resize(m_needResampling ? 8*m_samplesPerFrame : m_samplesPerFrame); // min input sampling rate is 0.125*m_sampleRate - m_rx->sampleAmplitudeTmp.resize (m_needResampling ? 8*m_samplesPerFrame*m_sampleSizeBytesInp : m_samplesPerFrame*m_sampleSizeBytesInp); + m_rx->sampleAmplitudeTmp.resize(m_needResampling ? 8*m_samplesPerFrame*m_sampleSizeBytesInp : m_samplesPerFrame*m_sampleSizeBytesInp); m_rx->rxData.resize(kMaxDataSize); @@ -510,8 +508,8 @@ GGWave::GGWave(const Parameters & parameters) : return; } - int totalLength = m_payloadLength + getECCBytesForLength(m_payloadLength); - int totalTxs = (totalLength + minBytesPerTx() - 1)/minBytesPerTx(); + const int totalLength = m_payloadLength + getECCBytesForLength(m_payloadLength); + const int totalTxs = (totalLength + minBytesPerTx() - 1)/minBytesPerTx(); m_rx->spectrumHistoryFixed.resize(totalTxs*maxFramesPerTx()); } else { @@ -520,6 +518,14 @@ GGWave::GGWave(const Parameters & parameters) : m_rx->sampleAmplitudeAverage.resize(m_samplesPerFrame); m_rx->sampleAmplitudeHistory.resize(kMaxSpectrumHistory); } + + for (auto & s : m_rx->sampleAmplitudeHistory) { + s.resize(m_samplesPerFrame); + } + + for (auto & s : m_rx->spectrumHistoryFixed) { + s.resize(m_samplesPerFrame); + } } if (m_isTxEnabled) { @@ -536,6 +542,16 @@ GGWave::GGWave(const Parameters & parameters) : // m_tx->waveformTones; } + // pre-allocate Reed-Solomon memory buffers + { + const auto maxLength = m_isFixedPayloadLength ? m_payloadLength : kMaxLengthVariable; + + if (m_isFixedPayloadLength == false) { + m_workRSLength.resize(RS::ReedSolomon::getWorkSize_bytes(1, m_encodedDataOffset - 1)); + } + m_workRSData.resize(RS::ReedSolomon::getWorkSize_bytes(maxLength, getECCBytesForLength(maxLength))); + } + if (m_needResampling) { m_resampler = std::unique_ptr(new Resampler()); } @@ -546,12 +562,12 @@ GGWave::GGWave(const Parameters & parameters) : GGWave::~GGWave() { } -bool GGWave::init(const std::string & text, const int volume) { - return init((int) text.size(), text.data(), getDefaultTxProtocol(), volume); +bool GGWave::init(const char * text, const int volume) { + return init(strlen(text), text, getDefaultTxProtocol(), volume); } -bool GGWave::init(const std::string & text, const TxProtocol & txProtocol, const int volume) { - return init((int) text.size(), text.data(), txProtocol, volume); +bool GGWave::init(const char * text, const TxProtocol & txProtocol, const int volume) { + return init(strlen(text), text, txProtocol, volume); } bool GGWave::init(int dataSize, const char * dataBuffer, const int volume) { @@ -564,32 +580,30 @@ bool GGWave::init(int dataSize, const char * dataBuffer, const TxProtocol & txPr return false; } - auto maxLength = m_isFixedPayloadLength ? m_payloadLength : kMaxLengthVarible; - if (dataSize > maxLength) { - ggprintf("Truncating data from %d to %d bytes\n", dataSize, maxLength); - dataSize = maxLength; - } - - if (volume < 0 || volume > 100) { - ggprintf("Invalid volume: %d\n", volume); - return false; - } - // Tx if (m_isTxEnabled) { + const auto maxLength = m_isFixedPayloadLength ? m_payloadLength : kMaxLengthVariable; + if (dataSize > maxLength) { + ggprintf("Truncating data from %d to %d bytes\n", dataSize, maxLength); + dataSize = maxLength; + } + + if (volume < 0 || volume > 100) { + ggprintf("Invalid volume: %d\n", volume); + return false; + } + m_tx->txProtocol = txProtocol; m_tx->txDataLength = dataSize; m_tx->sendVolume = ((double)(volume))/100.0f; - const uint8_t * text = reinterpret_cast(dataBuffer); - m_tx->hasNewTxData = false; std::fill(m_tx->txData.begin(), m_tx->txData.end(), 0); std::fill(m_dataEncoded.begin(), m_dataEncoded.end(), 0); if (m_tx->txDataLength > 0) { m_tx->txData[0] = m_tx->txDataLength; - for (int i = 0; i < m_tx->txDataLength; ++i) m_tx->txData[i + 1] = text[i]; + for (int i = 0; i < m_tx->txDataLength; ++i) m_tx->txData[i + 1] = dataBuffer[i]; m_tx->hasNewTxData = true; } @@ -612,7 +626,6 @@ bool GGWave::init(int dataSize, const char * dataBuffer, const TxProtocol & txPr std::fill(m_rx->sampleSpectrum.begin(), m_rx->sampleSpectrum.end(), 0); std::fill(m_rx->sampleAmplitude.begin(), m_rx->sampleAmplitude.end(), 0); for (auto & s : m_rx->sampleAmplitudeHistory) { - s.resize(m_samplesPerFrame); std::fill(s.begin(), s.end(), 0); } @@ -624,7 +637,6 @@ bool GGWave::init(int dataSize, const char * dataBuffer, const TxProtocol & txPr } for (auto & s : m_rx->spectrumHistoryFixed) { - s.resize(m_samplesPerFrame); std::fill(s.begin(), s.end(), 0); } } @@ -707,12 +719,12 @@ bool GGWave::encode(const CBWaveformOut & cbWaveformOut) { int totalDataFrames = ((totalBytes + m_tx->txProtocol.bytesPerTx - 1)/m_tx->txProtocol.bytesPerTx)*m_tx->txProtocol.framesPerTx; if (m_isFixedPayloadLength == false) { - RS::ReedSolomon rsLength(1, m_encodedDataOffset - 1); + RS::ReedSolomon rsLength(1, m_encodedDataOffset - 1, m_workRSLength.data()); rsLength.Encode(m_tx->txData.data(), m_dataEncoded.data()); } // first byte of m_tx->txData contains the length of the payload, so we skip it: - RS::ReedSolomon rsData = RS::ReedSolomon(m_tx->txDataLength, nECCBytesPerTx); + RS::ReedSolomon rsData = RS::ReedSolomon(m_tx->txDataLength, nECCBytesPerTx, m_workRSData.data()); rsData.Encode(m_tx->txData.data() + 1, m_dataEncoded.data() + m_encodedDataOffset); const float factor = m_sampleRate/m_sampleRateOut; @@ -723,7 +735,7 @@ bool GGWave::encode(const CBWaveformOut & cbWaveformOut) { while (m_tx->hasNewTxData) { std::fill(m_tx->outputBlock.begin(), m_tx->outputBlock.end(), 0.0f); - std::uint16_t nFreq = 0; + uint16_t nFreq = 0; m_tx->waveformTones.push_back({}); if (frameId < m_nMarkerFrames) { @@ -1385,7 +1397,7 @@ void GGWave::decode_variable() { } if (itx*rxProtocol.bytesPerTx > m_encodedDataOffset && knownLength == false) { - RS::ReedSolomon rsLength(1, m_encodedDataOffset - 1); + RS::ReedSolomon rsLength(1, m_encodedDataOffset - 1, m_workRSLength.data()); if ((rsLength.Decode(m_dataEncoded.data(), m_rx->rxData.data()) == 0) && (m_rx->rxData[0] > 0 && m_rx->rxData[0] <= 140)) { knownLength = true; decodedLength = m_rx->rxData[0]; @@ -1411,14 +1423,12 @@ void GGWave::decode_variable() { } if (knownLength) { - RS::ReedSolomon rsData(decodedLength, ::getECCBytesForLength(decodedLength)); + RS::ReedSolomon rsData(decodedLength, ::getECCBytesForLength(decodedLength), m_workRSData.data()); if (rsData.Decode(m_dataEncoded.data() + m_encodedDataOffset, m_rx->rxData.data()) == 0) { if (m_rx->rxData[0] != 0) { - std::string s((char *) m_rx->rxData.data(), decodedLength); - ggprintf("Decoded length = %d, protocol = '%s' (%d)\n", decodedLength, rxProtocol.name, rxProtocolId); - ggprintf("Received sound data successfully: '%s'\n", s.c_str()); + ggprintf("Received sound data successfully: '%s'\n", m_rx->rxData.data()); isValid = true; m_rx->hasNewRxData = true; @@ -1502,7 +1512,7 @@ void GGWave::decode_variable() { // max recieve duration m_rx->recvDuration_frames = 2*m_nMarkerFrames + - maxFramesPerTx()*((kMaxLengthVarible + ::getECCBytesForLength(kMaxLengthVarible))/minBytesPerTx() + 1); + maxFramesPerTx()*((kMaxLengthVariable + ::getECCBytesForLength(kMaxLengthVariable))/minBytesPerTx() + 1); m_rx->nMarkersSuccess = 0; m_rx->framesToRecord = m_rx->recvDuration_frames; @@ -1673,7 +1683,7 @@ void GGWave::decode_fixed() { } if (detectedSignal) { - RS::ReedSolomon rsData(m_payloadLength, getECCBytesForLength(m_payloadLength)); + RS::ReedSolomon rsData(m_payloadLength, getECCBytesForLength(m_payloadLength), m_workRSData.data()); for (int j = 0; j < totalLength; ++j) { m_dataEncoded[j] = (detectedBins[2*j + 1] << 4) + detectedBins[2*j + 0]; diff --git a/src/reed-solomon/rs.hpp b/src/reed-solomon/rs.hpp index 202cdd5..3e9903d 100644 --- a/src/reed-solomon/rs.hpp +++ b/src/reed-solomon/rs.hpp @@ -12,7 +12,6 @@ #include #include #include -#include namespace RS { @@ -24,12 +23,26 @@ public: const uint8_t msg_length; const uint8_t ecc_length; + uint8_t * heap_memory = nullptr; uint8_t * generator_cache = nullptr; - bool generator_cached = false; + bool owns_heap_memory = false; + bool generator_cached = false; - ReedSolomon(uint8_t msg_length_p, uint8_t ecc_length_p) : + // used to pre-allocate a memory buffer for the Reed-Solomon class in order to avoid memory allocations + static size_t getWorkSize_bytes(uint8_t msg_length, uint8_t ecc_length) { + return ecc_length + 1 + MSG_CNT * msg_length + POLY_CNT * ecc_length * 2; + } + + ReedSolomon(uint8_t msg_length_p, uint8_t ecc_length_p, uint8_t * heap_memory_p = nullptr) : msg_length(msg_length_p), ecc_length(ecc_length_p) { - generator_cache = new uint8_t[ecc_length + 1]; + if (heap_memory_p) { + heap_memory = heap_memory_p; + owns_heap_memory = false; + } else { + heap_memory = new uint8_t[getWorkSize_bytes(msg_length, ecc_length)]; + owns_heap_memory = true; + } + generator_cache = heap_memory; const uint8_t enc_len = msg_length + ecc_length; const uint8_t poly_len = ecc_length * 2; @@ -59,7 +72,9 @@ public: } ~ReedSolomon() { - delete [] generator_cache; + if (owns_heap_memory) { + delete[] heap_memory; + } // Dummy destructor, gcc-generated one crashes programm memory = NULL; } @@ -75,8 +90,7 @@ public: //this->memory = stack_memory; // gg : allocation is now on the heap - std::vector stack_memory(MSG_CNT * msg_length + POLY_CNT * ecc_length * 2); - this->memory = stack_memory.data(); + this->memory = heap_memory + ecc_length + 1; const uint8_t* src_ptr = (const uint8_t*) src; uint8_t* dst_ptr = (uint8_t*) dst; @@ -155,8 +169,7 @@ public: //this->memory = stack_memory; // gg : allocation is now on the heap - std::vector stack_memory(MSG_CNT * msg_length + POLY_CNT * ecc_length * 2); - this->memory = stack_memory.data(); + this->memory = heap_memory + ecc_length + 1; Poly *msg_in = &polynoms[ID_MSG_IN]; Poly *msg_out = &polynoms[ID_MSG_OUT]; diff --git a/tests/test-ggwave.cpp b/tests/test-ggwave.cpp index e5e352c..058fa36 100644 --- a/tests/test-ggwave.cpp +++ b/tests/test-ggwave.cpp @@ -209,7 +209,7 @@ int main(int argc, char ** argv) { std::string payload = "hello"; - CHECK(instance.init(payload)); + CHECK(instance.init(payload.c_str())); // data CHECK_F(instance.init(-1, "asd")); @@ -241,7 +241,7 @@ int main(int argc, char ** argv) { parameters.sampleRateOut = srInp; GGWave instanceOut(parameters); - instanceOut.init(payload, instanceOut.getTxProtocol(GGWAVE_TX_PROTOCOL_DT_FASTEST), 25); + instanceOut.init(payload.c_str(), instanceOut.getTxProtocol(GGWAVE_TX_PROTOCOL_DT_FASTEST), 25); auto expectedSize = instanceOut.encodeSize_samples(); instanceOut.encode(kCBWaveformOut.at(parameters.sampleFormatOut)); printf("Expected = %d, actual = %d\n", expectedSize, nSamples);