ecc : reduce memory allocations in reed-solomon lib

This commit is contained in:
Georgi Gerganov
2022-05-29 11:47:39 +03:00
parent aea3096b85
commit 1090fdb397
4 changed files with 91 additions and 67 deletions

View File

@@ -298,26 +298,25 @@ extern "C" {
#include <functional> #include <functional>
#include <vector> #include <vector>
#include <map> #include <map>
#include <string>
#include <memory> #include <memory>
class GGWave { class GGWave {
public: public:
static constexpr auto kSampleRateMin = 1000.0f; static constexpr auto kSampleRateMin = 1000.0f;
static constexpr auto kSampleRateMax = 96000.0f; static constexpr auto kSampleRateMax = 96000.0f;
static constexpr auto kDefaultSampleRate = 48000.0f; static constexpr auto kDefaultSampleRate = 48000.0f;
static constexpr auto kDefaultSamplesPerFrame = 1024; static constexpr auto kDefaultSamplesPerFrame = 1024;
static constexpr auto kDefaultVolume = 10; static constexpr auto kDefaultVolume = 10;
static constexpr auto kDefaultSoundMarkerThreshold = 3.0f; static constexpr auto kDefaultSoundMarkerThreshold = 3.0f;
static constexpr auto kDefaultMarkerFrames = 16; static constexpr auto kDefaultMarkerFrames = 16;
static constexpr auto kDefaultEncodedDataOffset = 3; static constexpr auto kDefaultEncodedDataOffset = 3;
static constexpr auto kMaxSamplesPerFrame = 1024; static constexpr auto kMaxSamplesPerFrame = 1024;
static constexpr auto kMaxDataBits = 256; static constexpr auto kMaxDataBits = 256;
static constexpr auto kMaxDataSize = 256; static constexpr auto kMaxDataSize = 256;
static constexpr auto kMaxLengthVarible = 140; static constexpr auto kMaxLengthVariable = 140;
static constexpr auto kMaxLengthFixed = 16; static constexpr auto kMaxLengthFixed = 16;
static constexpr auto kMaxSpectrumHistory = 4; static constexpr auto kMaxSpectrumHistory = 4;
static constexpr auto kMaxRecordedFrames = 2048; static constexpr auto kMaxRecordedFrames = 2048;
using Parameters = ggwave_Parameters; using Parameters = ggwave_Parameters;
using SampleFormat = ggwave_SampleFormat; using SampleFormat = ggwave_SampleFormat;
@@ -361,14 +360,14 @@ public:
double duration_ms; double duration_ms;
}; };
using Tones = std::vector<ToneData>; using Tones = std::vector<ToneData>;
using WaveformTones = std::vector<Tones>; using WaveformTones = std::vector<Tones>;
using AmplitudeData = std::vector<float>; using AmplitudeData = std::vector<float>;
using AmplitudeDataI16 = std::vector<int16_t>; using AmplitudeDataI16 = std::vector<int16_t>;
using SpectrumData = std::vector<float>; using SpectrumData = std::vector<float>;
using RecordedData = std::vector<float>; using RecordedData = std::vector<float>;
using TxRxData = std::vector<std::uint8_t>; using TxRxData = std::vector<uint8_t>;
using CBWaveformOut = std::function<void(const void * data, uint32_t nBytes)>; using CBWaveformOut = std::function<void(const void * data, uint32_t nBytes)>;
using CBWaveformInp = std::function<uint32_t(void * data, uint32_t nMaxBytes)>; using CBWaveformInp = std::function<uint32_t(void * data, uint32_t nMaxBytes)>;
@@ -394,8 +393,8 @@ public:
// //
// returns false upon invalid parameters or failure to initialize // returns false upon invalid parameters or failure to initialize
// //
bool init(const std::string & text, const int volume = kDefaultVolume); bool init(const char * text, const int volume = kDefaultVolume);
bool init(const std::string & text, const TxProtocol & txProtocol, const int volume = kDefaultVolume); bool init(const char * text, const TxProtocol & txProtocol, const int volume = kDefaultVolume);
bool init(int dataSize, const char * dataBuffer, const int volume = kDefaultVolume); bool init(int dataSize, const char * dataBuffer, const int volume = kDefaultVolume);
bool init(int dataSize, const char * dataBuffer, const TxProtocol & txProtocol, const int volume = kDefaultVolume); bool init(int dataSize, const char * dataBuffer, const TxProtocol & txProtocol, const int volume = kDefaultVolume);
@@ -576,6 +575,8 @@ private:
// common // common
TxRxData m_dataEncoded; TxRxData m_dataEncoded;
TxRxData m_workRSLength; // Reed-Solomon work buffers
TxRxData m_workRSData;
// Impl // Impl
struct Rx; struct Rx;

View File

@@ -4,8 +4,6 @@
#include <chrono> #include <chrono>
#include <cmath> #include <cmath>
#include <algorithm>
#include <stdexcept>
#include <map> #include <map>
//#include <random> //#include <random>
@@ -493,10 +491,10 @@ GGWave::GGWave(const Parameters & parameters) :
m_rx->fftInp.resize(m_samplesPerFrame); m_rx->fftInp.resize(m_samplesPerFrame);
m_rx->fftOut.resize(2*m_samplesPerFrame); m_rx->fftOut.resize(2*m_samplesPerFrame);
m_rx->sampleSpectrum.resize (m_samplesPerFrame); m_rx->sampleSpectrum.resize(m_samplesPerFrame);
m_rx->sampleAmplitude.resize (m_needResampling ? m_samplesPerFrame + 128 : m_samplesPerFrame); // small extra space because sometimes resampling needs a few more samples m_rx->sampleAmplitude.resize(m_needResampling ? m_samplesPerFrame + 128 : m_samplesPerFrame); // small extra space because sometimes resampling needs a few more samples
m_rx->sampleAmplitudeResampled.resize(m_needResampling ? 8*m_samplesPerFrame : m_samplesPerFrame); // min input sampling rate is 0.125*m_sampleRate m_rx->sampleAmplitudeResampled.resize(m_needResampling ? 8*m_samplesPerFrame : m_samplesPerFrame); // min input sampling rate is 0.125*m_sampleRate
m_rx->sampleAmplitudeTmp.resize (m_needResampling ? 8*m_samplesPerFrame*m_sampleSizeBytesInp : m_samplesPerFrame*m_sampleSizeBytesInp); m_rx->sampleAmplitudeTmp.resize(m_needResampling ? 8*m_samplesPerFrame*m_sampleSizeBytesInp : m_samplesPerFrame*m_sampleSizeBytesInp);
m_rx->rxData.resize(kMaxDataSize); m_rx->rxData.resize(kMaxDataSize);
@@ -510,8 +508,8 @@ GGWave::GGWave(const Parameters & parameters) :
return; return;
} }
int totalLength = m_payloadLength + getECCBytesForLength(m_payloadLength); const int totalLength = m_payloadLength + getECCBytesForLength(m_payloadLength);
int totalTxs = (totalLength + minBytesPerTx() - 1)/minBytesPerTx(); const int totalTxs = (totalLength + minBytesPerTx() - 1)/minBytesPerTx();
m_rx->spectrumHistoryFixed.resize(totalTxs*maxFramesPerTx()); m_rx->spectrumHistoryFixed.resize(totalTxs*maxFramesPerTx());
} else { } else {
@@ -520,6 +518,14 @@ GGWave::GGWave(const Parameters & parameters) :
m_rx->sampleAmplitudeAverage.resize(m_samplesPerFrame); m_rx->sampleAmplitudeAverage.resize(m_samplesPerFrame);
m_rx->sampleAmplitudeHistory.resize(kMaxSpectrumHistory); m_rx->sampleAmplitudeHistory.resize(kMaxSpectrumHistory);
} }
for (auto & s : m_rx->sampleAmplitudeHistory) {
s.resize(m_samplesPerFrame);
}
for (auto & s : m_rx->spectrumHistoryFixed) {
s.resize(m_samplesPerFrame);
}
} }
if (m_isTxEnabled) { if (m_isTxEnabled) {
@@ -536,6 +542,16 @@ GGWave::GGWave(const Parameters & parameters) :
// m_tx->waveformTones; // m_tx->waveformTones;
} }
// pre-allocate Reed-Solomon memory buffers
{
const auto maxLength = m_isFixedPayloadLength ? m_payloadLength : kMaxLengthVariable;
if (m_isFixedPayloadLength == false) {
m_workRSLength.resize(RS::ReedSolomon::getWorkSize_bytes(1, m_encodedDataOffset - 1));
}
m_workRSData.resize(RS::ReedSolomon::getWorkSize_bytes(maxLength, getECCBytesForLength(maxLength)));
}
if (m_needResampling) { if (m_needResampling) {
m_resampler = std::unique_ptr<Resampler>(new Resampler()); m_resampler = std::unique_ptr<Resampler>(new Resampler());
} }
@@ -546,12 +562,12 @@ GGWave::GGWave(const Parameters & parameters) :
GGWave::~GGWave() { GGWave::~GGWave() {
} }
bool GGWave::init(const std::string & text, const int volume) { bool GGWave::init(const char * text, const int volume) {
return init((int) text.size(), text.data(), getDefaultTxProtocol(), volume); return init(strlen(text), text, getDefaultTxProtocol(), volume);
} }
bool GGWave::init(const std::string & text, const TxProtocol & txProtocol, const int volume) { bool GGWave::init(const char * text, const TxProtocol & txProtocol, const int volume) {
return init((int) text.size(), text.data(), txProtocol, volume); return init(strlen(text), text, txProtocol, volume);
} }
bool GGWave::init(int dataSize, const char * dataBuffer, const int volume) { bool GGWave::init(int dataSize, const char * dataBuffer, const int volume) {
@@ -564,32 +580,30 @@ bool GGWave::init(int dataSize, const char * dataBuffer, const TxProtocol & txPr
return false; return false;
} }
auto maxLength = m_isFixedPayloadLength ? m_payloadLength : kMaxLengthVarible;
if (dataSize > maxLength) {
ggprintf("Truncating data from %d to %d bytes\n", dataSize, maxLength);
dataSize = maxLength;
}
if (volume < 0 || volume > 100) {
ggprintf("Invalid volume: %d\n", volume);
return false;
}
// Tx // Tx
if (m_isTxEnabled) { if (m_isTxEnabled) {
const auto maxLength = m_isFixedPayloadLength ? m_payloadLength : kMaxLengthVariable;
if (dataSize > maxLength) {
ggprintf("Truncating data from %d to %d bytes\n", dataSize, maxLength);
dataSize = maxLength;
}
if (volume < 0 || volume > 100) {
ggprintf("Invalid volume: %d\n", volume);
return false;
}
m_tx->txProtocol = txProtocol; m_tx->txProtocol = txProtocol;
m_tx->txDataLength = dataSize; m_tx->txDataLength = dataSize;
m_tx->sendVolume = ((double)(volume))/100.0f; m_tx->sendVolume = ((double)(volume))/100.0f;
const uint8_t * text = reinterpret_cast<const uint8_t *>(dataBuffer);
m_tx->hasNewTxData = false; m_tx->hasNewTxData = false;
std::fill(m_tx->txData.begin(), m_tx->txData.end(), 0); std::fill(m_tx->txData.begin(), m_tx->txData.end(), 0);
std::fill(m_dataEncoded.begin(), m_dataEncoded.end(), 0); std::fill(m_dataEncoded.begin(), m_dataEncoded.end(), 0);
if (m_tx->txDataLength > 0) { if (m_tx->txDataLength > 0) {
m_tx->txData[0] = m_tx->txDataLength; m_tx->txData[0] = m_tx->txDataLength;
for (int i = 0; i < m_tx->txDataLength; ++i) m_tx->txData[i + 1] = text[i]; for (int i = 0; i < m_tx->txDataLength; ++i) m_tx->txData[i + 1] = dataBuffer[i];
m_tx->hasNewTxData = true; m_tx->hasNewTxData = true;
} }
@@ -612,7 +626,6 @@ bool GGWave::init(int dataSize, const char * dataBuffer, const TxProtocol & txPr
std::fill(m_rx->sampleSpectrum.begin(), m_rx->sampleSpectrum.end(), 0); std::fill(m_rx->sampleSpectrum.begin(), m_rx->sampleSpectrum.end(), 0);
std::fill(m_rx->sampleAmplitude.begin(), m_rx->sampleAmplitude.end(), 0); std::fill(m_rx->sampleAmplitude.begin(), m_rx->sampleAmplitude.end(), 0);
for (auto & s : m_rx->sampleAmplitudeHistory) { for (auto & s : m_rx->sampleAmplitudeHistory) {
s.resize(m_samplesPerFrame);
std::fill(s.begin(), s.end(), 0); std::fill(s.begin(), s.end(), 0);
} }
@@ -624,7 +637,6 @@ bool GGWave::init(int dataSize, const char * dataBuffer, const TxProtocol & txPr
} }
for (auto & s : m_rx->spectrumHistoryFixed) { for (auto & s : m_rx->spectrumHistoryFixed) {
s.resize(m_samplesPerFrame);
std::fill(s.begin(), s.end(), 0); std::fill(s.begin(), s.end(), 0);
} }
} }
@@ -707,12 +719,12 @@ bool GGWave::encode(const CBWaveformOut & cbWaveformOut) {
int totalDataFrames = ((totalBytes + m_tx->txProtocol.bytesPerTx - 1)/m_tx->txProtocol.bytesPerTx)*m_tx->txProtocol.framesPerTx; int totalDataFrames = ((totalBytes + m_tx->txProtocol.bytesPerTx - 1)/m_tx->txProtocol.bytesPerTx)*m_tx->txProtocol.framesPerTx;
if (m_isFixedPayloadLength == false) { if (m_isFixedPayloadLength == false) {
RS::ReedSolomon rsLength(1, m_encodedDataOffset - 1); RS::ReedSolomon rsLength(1, m_encodedDataOffset - 1, m_workRSLength.data());
rsLength.Encode(m_tx->txData.data(), m_dataEncoded.data()); rsLength.Encode(m_tx->txData.data(), m_dataEncoded.data());
} }
// first byte of m_tx->txData contains the length of the payload, so we skip it: // first byte of m_tx->txData contains the length of the payload, so we skip it:
RS::ReedSolomon rsData = RS::ReedSolomon(m_tx->txDataLength, nECCBytesPerTx); RS::ReedSolomon rsData = RS::ReedSolomon(m_tx->txDataLength, nECCBytesPerTx, m_workRSData.data());
rsData.Encode(m_tx->txData.data() + 1, m_dataEncoded.data() + m_encodedDataOffset); rsData.Encode(m_tx->txData.data() + 1, m_dataEncoded.data() + m_encodedDataOffset);
const float factor = m_sampleRate/m_sampleRateOut; const float factor = m_sampleRate/m_sampleRateOut;
@@ -723,7 +735,7 @@ bool GGWave::encode(const CBWaveformOut & cbWaveformOut) {
while (m_tx->hasNewTxData) { while (m_tx->hasNewTxData) {
std::fill(m_tx->outputBlock.begin(), m_tx->outputBlock.end(), 0.0f); std::fill(m_tx->outputBlock.begin(), m_tx->outputBlock.end(), 0.0f);
std::uint16_t nFreq = 0; uint16_t nFreq = 0;
m_tx->waveformTones.push_back({}); m_tx->waveformTones.push_back({});
if (frameId < m_nMarkerFrames) { if (frameId < m_nMarkerFrames) {
@@ -1385,7 +1397,7 @@ void GGWave::decode_variable() {
} }
if (itx*rxProtocol.bytesPerTx > m_encodedDataOffset && knownLength == false) { if (itx*rxProtocol.bytesPerTx > m_encodedDataOffset && knownLength == false) {
RS::ReedSolomon rsLength(1, m_encodedDataOffset - 1); RS::ReedSolomon rsLength(1, m_encodedDataOffset - 1, m_workRSLength.data());
if ((rsLength.Decode(m_dataEncoded.data(), m_rx->rxData.data()) == 0) && (m_rx->rxData[0] > 0 && m_rx->rxData[0] <= 140)) { if ((rsLength.Decode(m_dataEncoded.data(), m_rx->rxData.data()) == 0) && (m_rx->rxData[0] > 0 && m_rx->rxData[0] <= 140)) {
knownLength = true; knownLength = true;
decodedLength = m_rx->rxData[0]; decodedLength = m_rx->rxData[0];
@@ -1411,14 +1423,12 @@ void GGWave::decode_variable() {
} }
if (knownLength) { if (knownLength) {
RS::ReedSolomon rsData(decodedLength, ::getECCBytesForLength(decodedLength)); RS::ReedSolomon rsData(decodedLength, ::getECCBytesForLength(decodedLength), m_workRSData.data());
if (rsData.Decode(m_dataEncoded.data() + m_encodedDataOffset, m_rx->rxData.data()) == 0) { if (rsData.Decode(m_dataEncoded.data() + m_encodedDataOffset, m_rx->rxData.data()) == 0) {
if (m_rx->rxData[0] != 0) { if (m_rx->rxData[0] != 0) {
std::string s((char *) m_rx->rxData.data(), decodedLength);
ggprintf("Decoded length = %d, protocol = '%s' (%d)\n", decodedLength, rxProtocol.name, rxProtocolId); ggprintf("Decoded length = %d, protocol = '%s' (%d)\n", decodedLength, rxProtocol.name, rxProtocolId);
ggprintf("Received sound data successfully: '%s'\n", s.c_str()); ggprintf("Received sound data successfully: '%s'\n", m_rx->rxData.data());
isValid = true; isValid = true;
m_rx->hasNewRxData = true; m_rx->hasNewRxData = true;
@@ -1502,7 +1512,7 @@ void GGWave::decode_variable() {
// max receive duration // max receive duration
m_rx->recvDuration_frames = m_rx->recvDuration_frames =
2*m_nMarkerFrames + 2*m_nMarkerFrames +
maxFramesPerTx()*((kMaxLengthVarible + ::getECCBytesForLength(kMaxLengthVarible))/minBytesPerTx() + 1); maxFramesPerTx()*((kMaxLengthVariable + ::getECCBytesForLength(kMaxLengthVariable))/minBytesPerTx() + 1);
m_rx->nMarkersSuccess = 0; m_rx->nMarkersSuccess = 0;
m_rx->framesToRecord = m_rx->recvDuration_frames; m_rx->framesToRecord = m_rx->recvDuration_frames;
@@ -1673,7 +1683,7 @@ void GGWave::decode_fixed() {
} }
if (detectedSignal) { if (detectedSignal) {
RS::ReedSolomon rsData(m_payloadLength, getECCBytesForLength(m_payloadLength)); RS::ReedSolomon rsData(m_payloadLength, getECCBytesForLength(m_payloadLength), m_workRSData.data());
for (int j = 0; j < totalLength; ++j) { for (int j = 0; j < totalLength; ++j) {
m_dataEncoded[j] = (detectedBins[2*j + 1] << 4) + detectedBins[2*j + 0]; m_dataEncoded[j] = (detectedBins[2*j + 1] << 4) + detectedBins[2*j + 0];

View File

@@ -12,7 +12,6 @@
#include <assert.h> #include <assert.h>
#include <string.h> #include <string.h>
#include <stdint.h> #include <stdint.h>
#include <vector>
namespace RS { namespace RS {
@@ -24,12 +23,26 @@ public:
const uint8_t msg_length; const uint8_t msg_length;
const uint8_t ecc_length; const uint8_t ecc_length;
uint8_t * heap_memory = nullptr;
uint8_t * generator_cache = nullptr; uint8_t * generator_cache = nullptr;
bool generator_cached = false; bool owns_heap_memory = false;
bool generator_cached = false;
ReedSolomon(uint8_t msg_length_p, uint8_t ecc_length_p) : // used to pre-allocate a memory buffer for the Reed-Solomon class in order to avoid memory allocations
static size_t getWorkSize_bytes(uint8_t msg_length, uint8_t ecc_length) {
return ecc_length + 1 + MSG_CNT * msg_length + POLY_CNT * ecc_length * 2;
}
ReedSolomon(uint8_t msg_length_p, uint8_t ecc_length_p, uint8_t * heap_memory_p = nullptr) :
msg_length(msg_length_p), ecc_length(ecc_length_p) { msg_length(msg_length_p), ecc_length(ecc_length_p) {
generator_cache = new uint8_t[ecc_length + 1]; if (heap_memory_p) {
heap_memory = heap_memory_p;
owns_heap_memory = false;
} else {
heap_memory = new uint8_t[getWorkSize_bytes(msg_length, ecc_length)];
owns_heap_memory = true;
}
generator_cache = heap_memory;
const uint8_t enc_len = msg_length + ecc_length; const uint8_t enc_len = msg_length + ecc_length;
const uint8_t poly_len = ecc_length * 2; const uint8_t poly_len = ecc_length * 2;
@@ -59,7 +72,9 @@ public:
} }
~ReedSolomon() { ~ReedSolomon() {
delete [] generator_cache; if (owns_heap_memory) {
delete[] heap_memory;
}
// Dummy destructor, gcc-generated one crashes the program // Dummy destructor, gcc-generated one crashes the program
memory = NULL; memory = NULL;
} }
@@ -75,8 +90,7 @@ public:
//this->memory = stack_memory; //this->memory = stack_memory;
// gg : allocation is now on the heap // gg : allocation is now on the heap
std::vector<uint8_t> stack_memory(MSG_CNT * msg_length + POLY_CNT * ecc_length * 2); this->memory = heap_memory + ecc_length + 1;
this->memory = stack_memory.data();
const uint8_t* src_ptr = (const uint8_t*) src; const uint8_t* src_ptr = (const uint8_t*) src;
uint8_t* dst_ptr = (uint8_t*) dst; uint8_t* dst_ptr = (uint8_t*) dst;
@@ -155,8 +169,7 @@ public:
//this->memory = stack_memory; //this->memory = stack_memory;
// gg : allocation is now on the heap // gg : allocation is now on the heap
std::vector<uint8_t> stack_memory(MSG_CNT * msg_length + POLY_CNT * ecc_length * 2); this->memory = heap_memory + ecc_length + 1;
this->memory = stack_memory.data();
Poly *msg_in = &polynoms[ID_MSG_IN]; Poly *msg_in = &polynoms[ID_MSG_IN];
Poly *msg_out = &polynoms[ID_MSG_OUT]; Poly *msg_out = &polynoms[ID_MSG_OUT];

View File

@@ -209,7 +209,7 @@ int main(int argc, char ** argv) {
std::string payload = "hello"; std::string payload = "hello";
CHECK(instance.init(payload)); CHECK(instance.init(payload.c_str()));
// data // data
CHECK_F(instance.init(-1, "asd")); CHECK_F(instance.init(-1, "asd"));
@@ -241,7 +241,7 @@ int main(int argc, char ** argv) {
parameters.sampleRateOut = srInp; parameters.sampleRateOut = srInp;
GGWave instanceOut(parameters); GGWave instanceOut(parameters);
instanceOut.init(payload, instanceOut.getTxProtocol(GGWAVE_TX_PROTOCOL_DT_FASTEST), 25); instanceOut.init(payload.c_str(), instanceOut.getTxProtocol(GGWAVE_TX_PROTOCOL_DT_FASTEST), 25);
auto expectedSize = instanceOut.encodeSize_samples(); auto expectedSize = instanceOut.encodeSize_samples();
instanceOut.encode(kCBWaveformOut.at(parameters.sampleFormatOut)); instanceOut.encode(kCBWaveformOut.at(parameters.sampleFormatOut));
printf("Expected = %d, actual = %d\n", expectedSize, nSamples); printf("Expected = %d, actual = %d\n", expectedSize, nSamples);