Distributed Framework #327

Draft · wants to merge 29 commits into base: master
Changes from 10 commits (29 commits total)

Commits
2567fe5
added demo notebooks
threewisemonkeys-as Aug 27, 2020
1aebe3c
Merge branch 'master' of https://github.com/SforAiDl/genrl
threewisemonkeys-as Aug 31, 2020
e3b8a8a
Merge branch 'master' of https://github.com/SforAiDl/genrl
threewisemonkeys-as Sep 3, 2020
4cba727
initial structure
threewisemonkeys-as Sep 3, 2020
3d233c4
added mp
threewisemonkeys-as Sep 6, 2020
1c504cc
add files
threewisemonkeys-as Sep 20, 2020
2c3298a
added new structure on rpc
threewisemonkeys-as Oct 1, 2020
73586d5
working structure
threewisemonkeys-as Oct 7, 2020
9ef6845
fixed integration bugs
threewisemonkeys-as Oct 7, 2020
072d545
removed unneccary files
threewisemonkeys-as Oct 7, 2020
64db1c1
added support for running from multiple scripts
threewisemonkeys-as Oct 8, 2020
4d57a06
added evaluate to trainer
threewisemonkeys-as Oct 8, 2020
f325429
added proxy getter
threewisemonkeys-as Oct 23, 2020
7ce19ec
added rpc backend option
threewisemonkeys-as Oct 23, 2020
cfba909
added logging to trainer
threewisemonkeys-as Oct 23, 2020
992a3a9
Added more options to trainer
threewisemonkeys-as Oct 23, 2020
bf1a50a
moved load weights to user
threewisemonkeys-as Oct 23, 2020
e2eef66
decreased number of eval its
threewisemonkeys-as Oct 23, 2020
837eb18
removed train wrapper
threewisemonkeys-as Oct 23, 2020
7fcbb23
removed loop to user fn
threewisemonkeys-as Oct 26, 2020
0002fa4
added example for secondary node
threewisemonkeys-as Oct 26, 2020
bebf50f
removed original exmpale
threewisemonkeys-as Oct 26, 2020
29bd1d6
removed fn
threewisemonkeys-as Oct 26, 2020
18536a2
shifted examples
threewisemonkeys-as Oct 29, 2020
8f859d6
shifted logger to base class
threewisemonkeys-as Oct 29, 2020
555e290
added on policy example
threewisemonkeys-as Oct 29, 2020
59e960c
removed temp example
threewisemonkeys-as Oct 29, 2020
8d5a8b6
got on policy distributed example to work
threewisemonkeys-as Oct 29, 2020
8030b2a
formatting
threewisemonkeys-as Oct 29, 2020
83 changes: 83 additions & 0 deletions examples/distributed.py
@@ -0,0 +1,83 @@
from genrl.distributed import (
Master,
ExperienceServer,
ParameterServer,
ActorNode,
LearnerNode,
WeightHolder,
)
from genrl.core import ReplayBuffer
from genrl.agents import DDPG
from genrl.trainers import DistributedTrainer
import gym
import argparse
import torch.multiprocessing as mp

parser = argparse.ArgumentParser()
parser.add_argument("-n", type=int)
args = parser.parse_args()

N_ACTORS = 2
BUFFER_SIZE = 10
MAX_ENV_STEPS = 100
TRAIN_STEPS = 10
BATCH_SIZE = 1


def collect_experience(agent, experience_server_rref):
obs = agent.env.reset()
done = False
for i in range(MAX_ENV_STEPS):
action = agent.select_action(obs)
next_obs, reward, done, info = agent.env.step(action)
experience_server_rref.rpc_sync().push((obs, action, reward, next_obs, done))
obs = next_obs
if done:
break
Member: Is this experience collection working on only a single agent / single thread?

Member Author: This is being run in multiple different processes. It's passed to the ActorNode, which runs it in its own loop (see genrl/distributed/actor.py below).

class MyTrainer(DistributedTrainer):
def __init__(self, agent, train_steps, batch_size):
super(MyTrainer, self).__init__(agent)
self.train_steps = train_steps
self.batch_size = batch_size

def train(self, parameter_server_rref, experience_server_rref):
i = 0
while i < self.train_steps:
batch = experience_server_rref.rpc_sync().sample(self.batch_size)
if batch is None:
continue
self.agent.update_params(batch, 1)
parameter_server_rref.rpc_sync().store_weights(self.agent.get_weights())
print(f"Trainer: {i + 1} / {self.train_steps} steps completed")
i += 1


mp.set_start_method("fork")

master = Master(world_size=6, address="localhost", port=29500)
env = gym.make("Pendulum-v0")
agent = DDPG("mlp", env)
parameter_server = ParameterServer(
"param-0", master, WeightHolder(agent.get_weights()), rank=1
)
buffer = ReplayBuffer(BUFFER_SIZE)
experience_server = ExperienceServer("experience-0", master, buffer, rank=2)
trainer = MyTrainer(agent, TRAIN_STEPS, BATCH_SIZE)
learner = LearnerNode(
"learner-0", master, parameter_server, experience_server, trainer, rank=3
)
actors = [
ActorNode(
f"actor-{i}",
master,
parameter_server,
experience_server,
learner,
agent,
collect_experience,
rank=i + 4,
)
for i in range(N_ACTORS)
]
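
WeightHolder is imported from genrl.distributed, but its definition is not part of the commits shown here. A minimal sketch of what it presumably looks like, inferred from the WeightHolder(agent.get_weights()) construction above and the store_weights / get_weights RPC calls; the attribute and argument names are assumptions:

class WeightHolder:
    # Sketch only: holds the latest copy of the agent's weights so actors can
    # fetch them over RPC while the learner keeps training.
    def __init__(self, init_weights):
        self._weights = init_weights

    def store_weights(self, weights):
        # Called by the trainer via parameter_server_rref.rpc_sync().store_weights(...)
        self._weights = weights

    def get_weights(self):
        # Called by actors via parameter_server_rref.rpc_sync().get_weights()
        return self._weights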
17 changes: 9 additions & 8 deletions genrl/agents/deep/base/offpolicy.py
@@ -80,7 +80,7 @@ def _reshape_batch(self, batch: List):
"""
return [*batch]

def sample_from_buffer(self, beta: float = None):
def sample_from_buffer(self, beta: float = None, batch = None):
"""Samples experiences from the buffer and converts them into usable formats

Args:
@@ -89,11 +89,12 @@ def sample_from_buffer(self, beta: float = None):
Returns:
batch (:obj:`list`): Replay experiences sampled from the buffer
"""
# Samples from the buffer
if beta is not None:
batch = self.replay_buffer.sample(self.batch_size, beta=beta)
else:
batch = self.replay_buffer.sample(self.batch_size)
if batch is None:
# Samples from the buffer
if beta is not None:
batch = self.replay_buffer.sample(self.batch_size, beta=beta)
else:
batch = self.replay_buffer.sample(self.batch_size)

states, actions, rewards, next_states, dones = self._reshape_batch(batch)

@@ -106,7 +107,7 @@ def sample_from_buffer(self, beta: float = None):
*[states, actions, rewards, next_states, dones, indices, weights]
)
else:
raise NotImplementedError
batch = ReplayBufferSamples(*[states, actions, rewards, next_states, dones])
return batch

def get_q_loss(self, batch: collections.namedtuple) -> torch.Tensor:
@@ -277,4 +278,4 @@ def load_weights(self, weights) -> None:
Args:
weights (:obj:`dict`): Dictionary of different neural net weights
"""
self.ac.load_state_dict(weights["weights"])
self.ac.load_state_dict(weights)
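
The net effect of the sample_from_buffer change above is that an externally supplied batch (for example, one fetched over RPC from a remote ExperienceServer) bypasses the agent's local replay buffer, while the old behaviour is kept when batch is None. A rough usage sketch, reusing the names from examples/distributed.py:

# Distributed path: the trainer fetches a batch from the remote experience
# server and hands it straight to the agent (as in MyTrainer.train above).
batch = experience_server_rref.rpc_sync().sample(BATCH_SIZE)
if batch is not None:
    agent.update_params(batch, 1)

# Local path: passing batch=None makes sample_from_buffer fall back to
# sampling from the agent's own replay buffer, as before.
agent.update_params(None, 1)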
7 changes: 5 additions & 2 deletions genrl/agents/deep/ddpg/ddpg.py
@@ -79,14 +79,14 @@ def _create_model(self) -> None:
self.optimizer_policy = opt.Adam(self.ac.actor.parameters(), lr=self.lr_policy)
self.optimizer_value = opt.Adam(self.ac.critic.parameters(), lr=self.lr_value)

def update_params(self, update_interval: int) -> None:
def update_params(self, batch, update_interval: int) -> None:
"""Update parameters of the model

Args:
update_interval (int): Interval between successive updates of the target model
"""
for timestep in range(update_interval):
batch = self.sample_from_buffer()
batch = self.sample_from_buffer(batch=batch)

value_loss = self.get_q_loss(batch)
self.logs["value_loss"].append(value_loss.item())
@@ -123,6 +123,9 @@ def get_hyperparams(self) -> Dict[str, Any]:
}
return hyperparams

def get_weights(self):
return self.ac.state_dict()

def get_logging_params(self) -> Dict[str, Any]:
"""Gets relevant parameters for logging

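Together with the load_weights change in offpolicy.py above, get_weights and load_weights now exchange a plain state_dict (no longer wrapped in a dict under a "weights" key), which is exactly what the parameter server stores. A small round-trip sketch using the same setup as the example script:

import gym

from genrl.agents import DDPG

learner_agent = DDPG("mlp", gym.make("Pendulum-v0"))
actor_agent = DDPG("mlp", gym.make("Pendulum-v0"))

weights = learner_agent.get_weights()  # actor-critic state_dict
actor_agent.load_weights(weights)      # loaded directly via load_state_dict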
3 changes: 3 additions & 0 deletions genrl/core/buffers.py
@@ -57,6 +57,9 @@ def sample(
:returns: (Tuple composing of `state`, `action`, `reward`,
`next_state` and `done`)
"""
if batch_size > len(self.memory):
return None

batch = random.sample(self.memory, batch_size)
state, action, reward, next_state, done = map(np.stack, zip(*batch))
return [
5 changes: 5 additions & 0 deletions genrl/distributed/__init__.py
@@ -0,0 +1,5 @@
from genrl.distributed.core import Master, Node
from genrl.distributed.parameter_server import ParameterServer, WeightHolder
from genrl.distributed.experience_server import ExperienceServer
from genrl.distributed.actor import ActorNode
from genrl.distributed.learner import LearnerNode
57 changes: 57 additions & 0 deletions genrl/distributed/actor.py
@@ -0,0 +1,57 @@
from genrl.distributed.core import Node
from genrl.distributed.core import get_rref, store_rref
import torch.distributed.rpc as rpc


class ActorNode(Node):
def __init__(
self,
name,
master,
parameter_server,
experience_server,
learner,
agent,
collect_experience,
rank=None,
):
super(ActorNode, self).__init__(name, master, rank)
self.parameter_server = parameter_server
self.experience_server = experience_server
self.init_proc(
target=self.act,
kwargs=dict(
parameter_server_name=parameter_server.name,
experience_server_name=experience_server.name,
learner_name=learner.name,
agent=agent,
collect_experience=collect_experience,
),
)
self.start_proc()

@staticmethod
def act(
name,
world_size,
rank,
parameter_server_name,
experience_server_name,
learner_name,
agent,
collect_experience,
**kwargs,
):
rpc.init_rpc(name=name, world_size=world_size, rank=rank)
print(f"{name}: RPC Initialised")
rref = rpc.RRef(agent)
store_rref(name, rref)
parameter_server_rref = get_rref(parameter_server_name)
experience_server_rref = get_rref(experience_server_name)
learner_rref = get_rref(learner_name)
print(f"{name}: Begining experience collection")
while not learner_rref.rpc_sync().is_done():
agent.load_weights(parameter_server_rref.rpc_sync().get_weights())
Sharad24 (Member), Oct 8, 2020: Would it be better to assign the agent in the constructor itself?

Member Author: Assign the agent weights? They will need to be updated in the loop, right?
collect_experience(agent, experience_server_rref)

rpc.shutdown()
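
The learner side of this handshake (the is_done() flag polled above, and the trainer that consumes the two server RRefs) lives in genrl/distributed/learner.py, which is not included in the commits shown here. A rough sketch of the shape it presumably takes, mirroring ActorNode and ExperienceServer; everything beyond the calls visible in this PR (in particular serving the trainer as the learner's RRef and an is_done flag on the trainer) is an assumption:

from genrl.distributed.core import Node, get_rref, store_rref

import torch.distributed.rpc as rpc


class LearnerNode(Node):
    # Sketch only: runs the user-provided trainer in its own process.
    def __init__(self, name, master, parameter_server, experience_server, trainer, rank=None):
        super(LearnerNode, self).__init__(name, master, rank)
        self.init_proc(
            target=self.learn,
            kwargs=dict(
                parameter_server_name=parameter_server.name,
                experience_server_name=experience_server.name,
                trainer=trainer,
            ),
        )
        self.start_proc()

    @staticmethod
    def learn(
        name,
        world_size,
        rank,
        parameter_server_name,
        experience_server_name,
        trainer,
        **kwargs,
    ):
        rpc.init_rpc(name=name, world_size=world_size, rank=rank)
        # Actors poll the learner's RRef for is_done(), so the served object
        # (assumed here to be the trainer) must expose that flag.
        store_rref(name, rpc.RRef(trainer))
        parameter_server_rref = get_rref(parameter_server_name)
        experience_server_rref = get_rref(experience_server_name)
        trainer.train(parameter_server_rref, experience_server_rref)
        rpc.shutdown()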
148 changes: 148 additions & 0 deletions genrl/distributed/core.py
@@ -0,0 +1,148 @@
import torch.distributed.rpc as rpc

import threading

from abc import ABC, abstractmethod
import torch.multiprocessing as mp
import os
import time

_rref_reg = {}
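# Registry of RRefs living in the master process: each node registers the RRef
# of the object it serves, and other nodes look those RRefs up by name over RPC.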
_global_lock = threading.Lock()


def _get_rref(idx):
global _rref_reg
with _global_lock:
if idx in _rref_reg.keys():
return _rref_reg[idx]
else:
return None


def _store_rref(idx, rref):
global _rref_reg
with _global_lock:
if idx in _rref_reg.keys():
raise Warning(
f"Re-assigning RRef for key: {idx}. Make sure you are not using duplicate names for nodes"
)
_rref_reg[idx] = rref


def get_rref(idx):
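# Ask the master for the RRef registered under `idx`, polling until the owning
# node has actually stored it (nodes come up asynchronously).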
rref = rpc.rpc_sync("master", _get_rref, args=(idx,))
while rref is None:
time.sleep(0.5)
rref = rpc.rpc_sync("master", _get_rref, args=(idx,))
return rref


def store_rref(idx, rref):
rpc.rpc_sync("master", _store_rref, args=(idx, rref))


def set_environ(address, port):
os.environ["MASTER_ADDR"] = str(address)
os.environ["MASTER_PORT"] = str(port)


class Node:
def __init__(self, name, master, rank):
self._name = name
self.master = master
if rank is None:
self._rank = master.node_count
elif rank >= 0 and rank < master.world_size:
self._rank = rank
elif rank >= master.world_size:
raise ValueError("Specified rank greater than allowed by world size")
else:
raise ValueError("Invalid value of rank")
self.p = None

def __del__(self):
if self.p is None:
raise RuntimeWarning(
"Removing node when process was not initialised properly"
)
else:
self.p.join()

@staticmethod
def _target_wrapper(target, **kwargs):
pid = os.getpid()
print(f"Starting {kwargs['name']} with pid {pid}")
set_environ(kwargs["master_address"], kwargs["master_port"])
target(**kwargs)
print(f"Shutdown {kwargs['name']} with pid {pid}")

def init_proc(self, target, kwargs):
kwargs.update(
dict(
name=self.name,
master_address=self.master.address,
master_port=self.master.port,
world_size=self.master.world_size,
rank=self.rank,
)
)
self.p = mp.Process(target=self._target_wrapper, args=(target,), kwargs=kwargs)

def start_proc(self):
if self.p is None:
raise RuntimeError("Trying to start uninitialised process")
self.p.start()

@property
def name(self):
return self._name

@property
def rref(self):
return get_rref(self.name)

@property
def rank(self):
return self._rank


def _run_master(world_size):
print(f"Starting master at {os.getpid()}")
rpc.init_rpc("master", rank=0, world_size=world_size)
rpc.shutdown()


class Master:
def __init__(self, world_size, address="localhost", port=29501):
set_environ(address, port)
self._world_size = world_size
self._address = address
self._port = port
self._node_counter = 0
self.p = mp.Process(target=_run_master, args=(world_size,))
self.p.start()

def __del__(self):
if self.p is None:
raise RuntimeWarning(
"Shutting down master when it was not initialised properly"
)
else:
self.p.join()

@property
def world_size(self):
return self._world_size

@property
def address(self):
return self._address

@property
def port(self):
return self._port

@property
def node_count(self):
return self._node_counter
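
Note that _run_master occupies rank 0 itself, so world_size has to count the master process plus every node; examples/distributed.py accordingly uses world_size=6 for one parameter server, one experience server, one learner, and two actors.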
23 changes: 23 additions & 0 deletions genrl/distributed/experience_server.py
@@ -0,0 +1,23 @@
from genrl.distributed import Node
from genrl.distributed.core import store_rref

import torch.distributed.rpc as rpc


class ExperienceServer(Node):
def __init__(self, name, master, buffer, rank=None):
super(ExperienceServer, self).__init__(name, master, rank)
self.init_proc(
target=self.run_parameter_server,
kwargs=dict(buffer=buffer),
)
self.start_proc()

@staticmethod
def run_parameter_server(name, world_size, rank, buffer, **kwargs):
rpc.init_rpc(name=name, world_size=world_size, rank=rank)
print(f"{name}: Initialised RPC")
rref = rpc.RRef(buffer)
store_rref(name, rref)
print(f"{name}: Serving experience buffer")
rpc.shutdown()
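
genrl/distributed/parameter_server.py is likewise not part of the commits shown here, but from the example script it is constructed with a WeightHolder and, by analogy with ExperienceServer above, presumably does little more than serve that object's RRef. A hedged sketch along those lines:

from genrl.distributed import Node
from genrl.distributed.core import store_rref

import torch.distributed.rpc as rpc


class ParameterServer(Node):
    # Sketch only: mirrors ExperienceServer, serving a WeightHolder instead of a buffer.
    def __init__(self, name, master, weight_holder, rank=None):
        super(ParameterServer, self).__init__(name, master, rank)
        self.init_proc(
            target=self.run_parameter_server,
            kwargs=dict(weight_holder=weight_holder),
        )
        self.start_proc()

    @staticmethod
    def run_parameter_server(name, world_size, rank, weight_holder, **kwargs):
        rpc.init_rpc(name=name, world_size=world_size, rank=rank)
        # Actors call get_weights() and the trainer calls store_weights() on this RRef.
        store_rref(name, rpc.RRef(weight_holder))
        rpc.shutdown()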