import random
from collections import deque


class ReplayBuffer(object):
    """ReplayBuffer for DRL
    Accept for tuple, (state, act, reward, next_state)"""

    def __init__(self, maxlen, data=list()):
        self._maxlen = maxlen
        self._buffer = deque(data, maxlen=maxlen)

    def __len__(self):
        return len(self._buffer)

    def sample(self, batch_size=32):
        if len(self._buffer) <= batch_size:
            return list(self._buffer)
        else:
            return random.sample(self._buffer, batch_size)

    def append(self, data):
        self._buffer.appendleft(data)

    def extend(self, data_list):
        self._buffer.extendleft(data_list)

    def clear(self):
        self._buffer.clear()

    @property
    def content(self):
        return list(self._buffer)

    @property
    def maxlen(self):
        return self._maxlen

    def __repr__(self):
        return self._buffer.__repr__()

    def __str__(self):
        return self._buffer.__str__()


def test_buffer():
    data = [(1, 2, 3)] * 5 + [(2, 3, 4)] * 3
    buffer = ReplayBuffer(7, data)
    print(buffer.content)
    buffer.append((3, 4, 5))
    print(buffer.content)
    buffer.extend([(6, 4, 2)] * 2)
    print(buffer.content, len(buffer))
    print(buffer.sample(3), len(buffer))
    print(buffer.sample(7))
    print(buffer.sample(8))
    buffer.clear()
    print(buffer.content, len(buffer))

results matching ""

    No results matching ""