diff --git a/pdu/bufpool.go b/pdu/bufpool.go index eb0602f..c50d947 100644 --- a/pdu/bufpool.go +++ b/pdu/bufpool.go @@ -1,48 +1,58 @@ package pdu import ( - "sync" + "sync" ) type BufferPoolManager struct { - pools map[uint]*sync.Pool - mu sync.Mutex + pools map[uint]*sync.Pool + mu sync.RWMutex } func NewBufferPoolManager() *BufferPoolManager { - return &BufferPoolManager{ - pools: make(map[uint]*sync.Pool), - } + return &BufferPoolManager{ + pools: make(map[uint]*sync.Pool), + } } -func (bpm *BufferPoolManager) GetBuffer(size uint) *([]uint8) { - bpm.mu.Lock() - pool, exists := bpm.pools[size] - if !exists { - pool = &sync.Pool{ - New: func() interface{} { - buf := make([]uint8, size) - return &buf - }, - } - bpm.pools[size] = pool - } - bpm.mu.Unlock() - return pool.Get().(*[]uint8) +func (bpm *BufferPoolManager) GetBuffer(size uint) *[]uint8 { + bpm.mu.RLock() + pool, exists := bpm.pools[size] + bpm.mu.RUnlock() + + if !exists { + bpm.mu.Lock() + // Double-check if another goroutine added the pool while we were waiting + pool, exists = bpm.pools[size] + if !exists { + pool = &sync.Pool{ + New: func() interface{} { + buf := make([]uint8, size) + return &buf + }, + } + bpm.pools[size] = pool + } + bpm.mu.Unlock() + } + + return pool.Get().(*[]uint8) } -func (bpm *BufferPoolManager) PutBuffer(buf *([]uint8)) { - size := uint(len(*buf)) - bpm.mu.Lock() - pool, exists := bpm.pools[size] - if !exists { - bpm.mu.Unlock() - return - } - bpm.mu.Unlock() - - for i := range *buf { - (*buf)[i] = 0 - } - pool.Put(buf) +func (bpm *BufferPoolManager) PutBuffer(buf *[]uint8) { + size := uint(len(*buf)) + bpm.mu.RLock() + pool, exists := bpm.pools[size] + bpm.mu.RUnlock() + + if !exists { + return + } + + // Clear buffer + for i := range *buf { + (*buf)[i] = 0 + } + + pool.Put(buf) } diff --git a/pdu/bufpool_test.go b/pdu/bufpool_test.go index 9ed5f82..db5c229 100644 --- a/pdu/bufpool_test.go +++ b/pdu/bufpool_test.go @@ -6,10 +6,7 @@ import ( ) func TestRetrieveBufferOfRequestedSize(t *testing.T) { - bpm := &BufferPoolManager{ - pools: make(map[uint]*sync.Pool), - mu: sync.Mutex{}, - } + bpm := NewBufferPoolManager() size := 1024 buffer := bpm.GetBuffer(uint(size)) @@ -24,10 +21,7 @@ func TestRetrieveBufferOfRequestedSize(t *testing.T) { } func TestRequestBufferSizeZero(t *testing.T) { - bpm := &BufferPoolManager{ - pools: make(map[uint]*sync.Pool), - mu: sync.Mutex{}, - } + bpm := NewBufferPoolManager() size := 0 buffer := bpm.GetBuffer(uint(size)) @@ -42,10 +36,7 @@ func TestRequestBufferSizeZero(t *testing.T) { } func TestConcurrentAccessToBufferPool(t *testing.T) { - bpm := &BufferPoolManager{ - pools: make(map[uint]*sync.Pool), - mu: sync.Mutex{}, - } + bpm := NewBufferPoolManager() size := 1024 var wg sync.WaitGroup @@ -69,10 +60,7 @@ func TestConcurrentAccessToBufferPool(t *testing.T) { } func TestGetBufferLockUnlock(t *testing.T) { - bpm := &BufferPoolManager{ - pools: make(map[uint]*sync.Pool), - mu: sync.Mutex{}, - } + bpm := NewBufferPoolManager() size := 1024 buffer := bpm.GetBuffer(uint(size)) @@ -87,10 +75,7 @@ func TestGetBufferLockUnlock(t *testing.T) { } func TestVerifyPoolCreationForNewSizes(t *testing.T) { - bpm := &BufferPoolManager{ - pools: make(map[uint]*sync.Pool), - mu: sync.Mutex{}, - } + bpm := NewBufferPoolManager() size := 512 buffer := bpm.GetBuffer(uint(size)) @@ -105,10 +90,7 @@ func TestVerifyPoolCreationForNewSizes(t *testing.T) { } func TestBufferPoolManagerGetBuffer(t *testing.T) { - bpm := &BufferPoolManager{ - pools: make(map[uint]*sync.Pool), - mu: sync.Mutex{}, - } + bpm := NewBufferPoolManager() size := 1024 buffer := bpm.GetBuffer(uint(size)) @@ -123,10 +105,7 @@ func TestBufferPoolManagerGetBuffer(t *testing.T) { } func TestGetBufferWithMultipleSizes(t *testing.T) { - bpm := &BufferPoolManager{ - pools: make(map[uint]*sync.Pool), - mu: sync.Mutex{}, - } + bpm := NewBufferPoolManager() sizes := []int{512, 1024, 2048} for _, size := range sizes { @@ -141,3 +120,28 @@ func TestGetBufferWithMultipleSizes(t *testing.T) { } } } + +func TestGetBufferIsAlwaysZero(t *testing.T) { + bpm := NewBufferPoolManager() + + var size uint = 1024*64 + for i := 0; i < 1000; i++ { + buffer := bpm.GetBuffer(size) + + if buffer == nil { + t.Fatalf("Expected buffer for size %d, got nil", size) + } + + if uint(len(*buffer)) != size { + t.Errorf("Expected buffer size %d, got %d", size, len(*buffer)) + } + + for _, b := range *buffer { + if b != 0 { + t.Errorf("Expected buffer to be zero, got %d", b) + } + } + + bpm.PutBuffer(buffer) + } +} diff --git a/pdu/global.go b/pdu/global.go index 6203035..d06cdb6 100644 --- a/pdu/global.go +++ b/pdu/global.go @@ -1,8 +1,3 @@ package pdu -import "sync" - -var ByteBufferPool = &BufferPoolManager{ - pools: make(map[uint]*sync.Pool), - mu: sync.Mutex{}, -} +var ByteBufferPool = NewBufferPoolManager()