From 2567fe5ecc3eba9b170507d8da575a75a7e36f85 Mon Sep 17 00:00:00 2001 From: Atharv Sonwane Date: Fri, 28 Aug 2020 01:17:15 +0530 Subject: [PATCH 01/27] added demo notebooks --- examples/Bandit_demo.ipynb | 825 +++++++++++++++++++++++++++++++++++++ examples/DQN_demo.ipynb | 369 +++++++++++++++++ 2 files changed, 1194 insertions(+) create mode 100644 examples/Bandit_demo.ipynb create mode 100644 examples/DQN_demo.ipynb diff --git a/examples/Bandit_demo.ipynb b/examples/Bandit_demo.ipynb new file mode 100644 index 00000000..c30181b5 --- /dev/null +++ b/examples/Bandit_demo.ipynb @@ -0,0 +1,825 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "GenRL Bandit Demo", + "provenance": [], + "collapsed_sections": [], + "authorship_tag": "ABX9TyM/n2lbgIBtTIt8joCJyzNp", + "include_colab_link": true + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + } + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "view-in-github", + "colab_type": "text" + }, + "source": [ + "\"Open" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "YWkMa_yBS2hT", + "colab_type": "text" + }, + "source": [ + "# Example of using Bandits from GenRL" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "2sXT4gp-O4za", + "colab_type": "text" + }, + "source": [ + "## Setup" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "Pbkxqa7t2UQE", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 1000 + }, + "outputId": "87439d9f-15f4-4340-df18-1a55578c8458" + }, + "source": [ + "!git clone https://github.com/SforAiDl/genrl.git\n", + "!pip install -e genrl" + ], + "execution_count": 1, + "outputs": [ + { + "output_type": "stream", + "text": [ + "Cloning into 'genrl'...\n", + "remote: Enumerating objects: 119, done.\u001b[K\n", + "remote: Counting objects: 100% (119/119), done.\u001b[K\n", + "remote: Compressing objects: 100% (107/107), done.\u001b[K\n", + "remote: Total 7155 (delta 35), reused 30 (delta 8), pack-reused 7036\u001b[K\n", + "Receiving objects: 100% (7155/7155), 7.59 MiB | 14.18 MiB/s, done.\n", + "Resolving deltas: 100% (4325/4325), done.\n", + "Obtaining file:///content/genrl\n", + "Requirement already satisfied: atari-py==0.2.6 in /usr/local/lib/python3.6/dist-packages (from genrl==0.0.1) (0.2.6)\n", + "Collecting box2d-py==2.3.8\n", + "\u001b[?25l Downloading https://files.pythonhosted.org/packages/06/bd/6cdc3fd994b0649dcf5d9bad85bd9e26172308bbe9a421bfc6fdbf5081a6/box2d_py-2.3.8-cp36-cp36m-manylinux1_x86_64.whl (448kB)\n", + "\u001b[K |████████████████████████████████| 450kB 5.9MB/s \n", + "\u001b[?25hCollecting certifi==2019.11.28\n", + "\u001b[?25l Downloading https://files.pythonhosted.org/packages/b9/63/df50cac98ea0d5b006c55a399c3bf1db9da7b5a24de7890bc9cfd5dd9e99/certifi-2019.11.28-py2.py3-none-any.whl (156kB)\n", + "\u001b[K |████████████████████████████████| 163kB 9.2MB/s \n", + "\u001b[?25hRequirement already satisfied: cloudpickle==1.3.0 in /usr/local/lib/python3.6/dist-packages (from genrl==0.0.1) (1.3.0)\n", + "Collecting future==0.18.2\n", + "\u001b[?25l Downloading https://files.pythonhosted.org/packages/45/0b/38b06fd9b92dc2b68d58b75f900e97884c45bedd2ff83203d933cf5851c9/future-0.18.2.tar.gz (829kB)\n", + "\u001b[K |████████████████████████████████| 829kB 14.3MB/s \n", + "\u001b[?25hCollecting gym==0.17.1\n", + "\u001b[?25l Downloading https://files.pythonhosted.org/packages/77/48/c43b8a72b916cc70896aa431b0fc00d1481ae34e28dc55e2144f4c77916b/gym-0.17.1.tar.gz 
(1.6MB)\n", + "\u001b[K |████████████████████████████████| 1.6MB 33.0MB/s \n", + "\u001b[?25hCollecting numpy==1.18.2\n", + "\u001b[?25l Downloading https://files.pythonhosted.org/packages/07/08/a549ba8b061005bb629b76adc000f3caaaf881028b963c2e18f811c6edc1/numpy-1.18.2-cp36-cp36m-manylinux1_x86_64.whl (20.2MB)\n", + "\u001b[K |████████████████████████████████| 20.2MB 1.5MB/s \n", + "\u001b[?25hCollecting opencv-python==4.2.0.34\n", + "\u001b[?25l Downloading https://files.pythonhosted.org/packages/72/c2/e9cf54ae5b1102020ef895866a67cb2e1aef72f16dd1fde5b5fb1495ad9c/opencv_python-4.2.0.34-cp36-cp36m-manylinux1_x86_64.whl (28.2MB)\n", + "\u001b[K |████████████████████████████████| 28.2MB 91kB/s \n", + "\u001b[?25hCollecting pandas==1.0.4\n", + "\u001b[?25l Downloading https://files.pythonhosted.org/packages/8e/86/c14387d6813ebadb7bf61b9ad270ffff111c8b587e4d266e07de774e385e/pandas-1.0.4-cp36-cp36m-manylinux1_x86_64.whl (10.1MB)\n", + "\u001b[K |████████████████████████████████| 10.1MB 48.5MB/s \n", + "\u001b[?25hCollecting Pillow==7.1.0\n", + "\u001b[?25l Downloading https://files.pythonhosted.org/packages/ef/73/468ff799fc61607b4f37698d02e4a6699b96f8c0abfa8d973e717bafcba4/Pillow-7.1.0-cp36-cp36m-manylinux1_x86_64.whl (2.1MB)\n", + "\u001b[K |████████████████████████████████| 2.1MB 42.9MB/s \n", + "\u001b[?25hRequirement already satisfied: pyglet==1.5.0 in /usr/local/lib/python3.6/dist-packages (from genrl==0.0.1) (1.5.0)\n", + "Requirement already satisfied: scipy==1.4.1 in /usr/local/lib/python3.6/dist-packages (from genrl==0.0.1) (1.4.1)\n", + "Collecting six==1.14.0\n", + " Downloading https://files.pythonhosted.org/packages/65/eb/1f97cb97bfc2390a276969c6fae16075da282f5058082d4cb10c6c5c1dba/six-1.14.0-py2.py3-none-any.whl\n", + "Collecting matplotlib==3.2.1\n", + "\u001b[?25l Downloading https://files.pythonhosted.org/packages/93/4b/52da6b1523d5139d04e02d9e26ceda6146b48f2a4e5d2abfdf1c7bac8c40/matplotlib-3.2.1-cp36-cp36m-manylinux1_x86_64.whl (12.4MB)\n", + "\u001b[K |████████████████████████████████| 12.4MB 17.6MB/s \n", + "\u001b[?25hCollecting pytest==5.4.1\n", + "\u001b[?25l Downloading https://files.pythonhosted.org/packages/c7/e2/c19c667f42f72716a7d03e8dd4d6f63f47d39feadd44cc1ee7ca3089862c/pytest-5.4.1-py3-none-any.whl (246kB)\n", + "\u001b[K |████████████████████████████████| 256kB 46.0MB/s \n", + "\u001b[?25hCollecting torch==1.4.0\n", + "\u001b[?25l Downloading https://files.pythonhosted.org/packages/24/19/4804aea17cd136f1705a5e98a00618cb8f6ccc375ad8bfa437408e09d058/torch-1.4.0-cp36-cp36m-manylinux1_x86_64.whl (753.4MB)\n", + "\u001b[K |████████████████████████████████| 753.4MB 18kB/s \n", + "\u001b[?25hCollecting torchvision==0.5.0\n", + "\u001b[?25l Downloading https://files.pythonhosted.org/packages/7e/90/6141bf41f5655c78e24f40f710fdd4f8a8aff6c8b7c6f0328240f649bdbe/torchvision-0.5.0-cp36-cp36m-manylinux1_x86_64.whl (4.0MB)\n", + "\u001b[K |████████████████████████████████| 4.0MB 39.4MB/s \n", + "\u001b[?25hCollecting tensorflow-tensorboard==1.5.1\n", + "\u001b[?25l Downloading https://files.pythonhosted.org/packages/cc/fa/91c06952517b4f1bc075545b062a4112e30cebe558a6b962816cb33efa27/tensorflow_tensorboard-1.5.1-py3-none-any.whl (3.0MB)\n", + "\u001b[K |████████████████████████████████| 3.0MB 38.6MB/s \n", + "\u001b[?25hCollecting tensorboard==1.15.0\n", + "\u001b[?25l Downloading https://files.pythonhosted.org/packages/1e/e9/d3d747a97f7188f48aa5eda486907f3b345cd409f0a0850468ba867db246/tensorboard-1.15.0-py3-none-any.whl (3.8MB)\n", + "\u001b[K |████████████████████████████████| 
3.8MB 44.8MB/s \n", + "\u001b[?25hCollecting pre-commit==2.4.0\n", + "\u001b[?25l Downloading https://files.pythonhosted.org/packages/fc/70/b162b8e15689427006853a245b45040435534a568cc773be098bfaebd36d/pre_commit-2.4.0-py2.py3-none-any.whl (171kB)\n", + "\u001b[K |████████████████████████████████| 174kB 50.0MB/s \n", + "\u001b[?25hCollecting importlib-resources==1.0.1\n", + " Downloading https://files.pythonhosted.org/packages/81/a3/466b268701207e00b4440803be132e892bd9fc74f1fe786d7e33146ad2c7/importlib_resources-1.0.1-py2.py3-none-any.whl\n", + "Collecting setuptools==41.0.0\n", + "\u001b[?25l Downloading https://files.pythonhosted.org/packages/c8/b0/cc6b7ba28d5fb790cf0d5946df849233e32b8872b6baca10c9e002ff5b41/setuptools-41.0.0-py2.py3-none-any.whl (575kB)\n", + "\u001b[K |████████████████████████████████| 583kB 42.7MB/s \n", + "\u001b[?25hRequirement already satisfied: pytz>=2017.2 in /usr/local/lib/python3.6/dist-packages (from pandas==1.0.4->genrl==0.0.1) (2018.9)\n", + "Requirement already satisfied: python-dateutil>=2.6.1 in /usr/local/lib/python3.6/dist-packages (from pandas==1.0.4->genrl==0.0.1) (2.8.1)\n", + "Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib==3.2.1->genrl==0.0.1) (1.2.0)\n", + "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.6/dist-packages (from matplotlib==3.2.1->genrl==0.0.1) (0.10.0)\n", + "Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib==3.2.1->genrl==0.0.1) (2.4.7)\n", + "Requirement already satisfied: packaging in /usr/local/lib/python3.6/dist-packages (from pytest==5.4.1->genrl==0.0.1) (20.4)\n", + "Collecting pluggy<1.0,>=0.12\n", + " Downloading https://files.pythonhosted.org/packages/a0/28/85c7aa31b80d150b772fbe4a229487bc6644da9ccb7e427dd8cc60cb8a62/pluggy-0.13.1-py2.py3-none-any.whl\n", + "Requirement already satisfied: wcwidth in /usr/local/lib/python3.6/dist-packages (from pytest==5.4.1->genrl==0.0.1) (0.2.5)\n", + "Requirement already satisfied: py>=1.5.0 in /usr/local/lib/python3.6/dist-packages (from pytest==5.4.1->genrl==0.0.1) (1.9.0)\n", + "Requirement already satisfied: attrs>=17.4.0 in /usr/local/lib/python3.6/dist-packages (from pytest==5.4.1->genrl==0.0.1) (20.1.0)\n", + "Requirement already satisfied: more-itertools>=4.0.0 in /usr/local/lib/python3.6/dist-packages (from pytest==5.4.1->genrl==0.0.1) (8.4.0)\n", + "Requirement already satisfied: importlib-metadata>=0.12; python_version < \"3.8\" in /usr/local/lib/python3.6/dist-packages (from pytest==5.4.1->genrl==0.0.1) (1.7.0)\n", + "Collecting bleach==1.5.0\n", + " Downloading https://files.pythonhosted.org/packages/33/70/86c5fec937ea4964184d4d6c4f0b9551564f821e1c3575907639036d9b90/bleach-1.5.0-py2.py3-none-any.whl\n", + "Requirement already satisfied: wheel>=0.26; python_version >= \"3\" in /usr/local/lib/python3.6/dist-packages (from tensorflow-tensorboard==1.5.1->genrl==0.0.1) (0.35.1)\n", + "Requirement already satisfied: markdown>=2.6.8 in /usr/local/lib/python3.6/dist-packages (from tensorflow-tensorboard==1.5.1->genrl==0.0.1) (3.2.2)\n", + "Requirement already satisfied: protobuf>=3.4.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow-tensorboard==1.5.1->genrl==0.0.1) (3.12.4)\n", + "Collecting html5lib==0.9999999\n", + "\u001b[?25l Downloading https://files.pythonhosted.org/packages/ae/ae/bcb60402c60932b32dfaf19bb53870b29eda2cd17551ba5639219fb5ebf9/html5lib-0.9999999.tar.gz (889kB)\n", + "\u001b[K 
|████████████████████████████████| 890kB 39.4MB/s \n", + "\u001b[?25hRequirement already satisfied: werkzeug>=0.11.10 in /usr/local/lib/python3.6/dist-packages (from tensorflow-tensorboard==1.5.1->genrl==0.0.1) (1.0.1)\n", + "Requirement already satisfied: absl-py>=0.4 in /usr/local/lib/python3.6/dist-packages (from tensorboard==1.15.0->genrl==0.0.1) (0.8.1)\n", + "Requirement already satisfied: grpcio>=1.6.3 in /usr/local/lib/python3.6/dist-packages (from tensorboard==1.15.0->genrl==0.0.1) (1.31.0)\n", + "Collecting cfgv>=2.0.0\n", + " Downloading https://files.pythonhosted.org/packages/45/cd/3878c9248e59e5e2ebd0dc741ab984b18d86e7283ae9b127b05fc287d239/cfgv-3.2.0-py2.py3-none-any.whl\n", + "Collecting virtualenv>=20.0.8\n", + "\u001b[?25l Downloading https://files.pythonhosted.org/packages/12/51/36c685ff2c1b2f7b4b5db29f3153159102ae0e0adaff3a26fd1448232e06/virtualenv-20.0.31-py2.py3-none-any.whl (4.9MB)\n", + "\u001b[K |████████████████████████████████| 4.9MB 41.4MB/s \n", + "\u001b[?25hCollecting pyyaml>=5.1\n", + "\u001b[?25l Downloading https://files.pythonhosted.org/packages/64/c2/b80047c7ac2478f9501676c988a5411ed5572f35d1beff9cae07d321512c/PyYAML-5.3.1.tar.gz (269kB)\n", + "\u001b[K |████████████████████████████████| 276kB 36.7MB/s \n", + "\u001b[?25hCollecting nodeenv>=0.11.1\n", + " Downloading https://files.pythonhosted.org/packages/ae/d0/efdf54539948315cc76e5a66b709212963101d002822c3b54369dbf9b5e0/nodeenv-1.5.0-py2.py3-none-any.whl\n", + "Collecting identify>=1.0.0\n", + "\u001b[?25l Downloading https://files.pythonhosted.org/packages/fd/e0/593fed766c1ddbf7dfb3bdbc997af3c5dcbf7862392f4ee5fef744128622/identify-1.4.29-py2.py3-none-any.whl (97kB)\n", + "\u001b[K |████████████████████████████████| 102kB 11.4MB/s \n", + "\u001b[?25hRequirement already satisfied: toml in /usr/local/lib/python3.6/dist-packages (from pre-commit==2.4.0->genrl==0.0.1) (0.10.1)\n", + "Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.6/dist-packages (from importlib-metadata>=0.12; python_version < \"3.8\"->pytest==5.4.1->genrl==0.0.1) (3.1.0)\n", + "Collecting appdirs<2,>=1.4.3\n", + " Downloading https://files.pythonhosted.org/packages/3b/00/2344469e2084fb287c2e0b57b72910309874c3245463acd6cf5e3db69324/appdirs-1.4.4-py2.py3-none-any.whl\n", + "Requirement already satisfied: filelock<4,>=3.0.0 in /usr/local/lib/python3.6/dist-packages (from virtualenv>=20.0.8->pre-commit==2.4.0->genrl==0.0.1) (3.0.12)\n", + "Collecting distlib<1,>=0.3.1\n", + "\u001b[?25l Downloading https://files.pythonhosted.org/packages/f5/0a/490fa011d699bb5a5f3a0cf57de82237f52a6db9d40f33c53b2736c9a1f9/distlib-0.3.1-py2.py3-none-any.whl (335kB)\n", + "\u001b[K |████████████████████████████████| 337kB 42.3MB/s \n", + "\u001b[?25hBuilding wheels for collected packages: future, gym, html5lib, pyyaml\n", + " Building wheel for future (setup.py) ... \u001b[?25l\u001b[?25hdone\n", + " Created wheel for future: filename=future-0.18.2-cp36-none-any.whl size=491057 sha256=8cf52b30de5fee8e76119c05431d1a52eef7ae39a131d23fa70b3d2bfeb1c13d\n", + " Stored in directory: /root/.cache/pip/wheels/8b/99/a0/81daf51dcd359a9377b110a8a886b3895921802d2fc1b2397e\n", + " Building wheel for gym (setup.py) ... \u001b[?25l\u001b[?25hdone\n", + " Created wheel for gym: filename=gym-0.17.1-cp36-none-any.whl size=1648710 sha256=f5ef583b9328385c7c344c372ca3ae1a085b8aa56557482d98a9905383612668\n", + " Stored in directory: /root/.cache/pip/wheels/c0/84/61/523b92d88787ae29689b3cc08cf445d8d8186d7fbe1acbf87b\n", + " Building wheel for html5lib (setup.py) ... 
\u001b[?25l\u001b[?25hdone\n", + " Created wheel for html5lib: filename=html5lib-0.9999999-cp36-none-any.whl size=107220 sha256=b80489e513128c0bb22be245f1b98637449c9272470ad870e93375e5acdffdb7\n", + " Stored in directory: /root/.cache/pip/wheels/50/ae/f9/d2b189788efcf61d1ee0e36045476735c838898eef1cad6e29\n", + " Building wheel for pyyaml (setup.py) ... \u001b[?25l\u001b[?25hdone\n", + " Created wheel for pyyaml: filename=PyYAML-5.3.1-cp36-cp36m-linux_x86_64.whl size=44619 sha256=77b47e0ba1f5f22e5bd75345640c374869a74cae41ddde00ddb8521ab7db4f20\n", + " Stored in directory: /root/.cache/pip/wheels/a7/c1/ea/cf5bd31012e735dc1dfea3131a2d5eae7978b251083d6247bd\n", + "Successfully built future gym html5lib pyyaml\n", + "\u001b[31mERROR: xarray 0.15.1 has requirement setuptools>=41.2, but you'll have setuptools 41.0.0 which is incompatible.\u001b[0m\n", + "\u001b[31mERROR: tensorflow 2.3.0 has requirement tensorboard<3,>=2.3.0, but you'll have tensorboard 1.15.0 which is incompatible.\u001b[0m\n", + "\u001b[31mERROR: google-colab 1.0.0 has requirement six~=1.15.0, but you'll have six 1.14.0 which is incompatible.\u001b[0m\n", + "\u001b[31mERROR: datascience 0.10.6 has requirement folium==0.2.1, but you'll have folium 0.8.3 which is incompatible.\u001b[0m\n", + "\u001b[31mERROR: albumentations 0.1.12 has requirement imgaug<0.2.7,>=0.2.5, but you'll have imgaug 0.2.9 which is incompatible.\u001b[0m\n", + "Installing collected packages: box2d-py, certifi, future, numpy, six, gym, opencv-python, pandas, Pillow, matplotlib, pluggy, pytest, torch, torchvision, html5lib, bleach, tensorflow-tensorboard, setuptools, tensorboard, cfgv, appdirs, importlib-resources, distlib, virtualenv, pyyaml, nodeenv, identify, pre-commit, genrl\n", + " Found existing installation: certifi 2020.6.20\n", + " Uninstalling certifi-2020.6.20:\n", + " Successfully uninstalled certifi-2020.6.20\n", + " Found existing installation: future 0.16.0\n", + " Uninstalling future-0.16.0:\n", + " Successfully uninstalled future-0.16.0\n", + " Found existing installation: numpy 1.18.5\n", + " Uninstalling numpy-1.18.5:\n", + " Successfully uninstalled numpy-1.18.5\n", + " Found existing installation: six 1.15.0\n", + " Uninstalling six-1.15.0:\n", + " Successfully uninstalled six-1.15.0\n", + " Found existing installation: gym 0.17.2\n", + " Uninstalling gym-0.17.2:\n", + " Successfully uninstalled gym-0.17.2\n", + " Found existing installation: opencv-python 4.1.2.30\n", + " Uninstalling opencv-python-4.1.2.30:\n", + " Successfully uninstalled opencv-python-4.1.2.30\n", + " Found existing installation: pandas 1.0.5\n", + " Uninstalling pandas-1.0.5:\n", + " Successfully uninstalled pandas-1.0.5\n", + " Found existing installation: Pillow 7.0.0\n", + " Uninstalling Pillow-7.0.0:\n", + " Successfully uninstalled Pillow-7.0.0\n", + " Found existing installation: matplotlib 3.2.2\n", + " Uninstalling matplotlib-3.2.2:\n", + " Successfully uninstalled matplotlib-3.2.2\n", + " Found existing installation: pluggy 0.7.1\n", + " Uninstalling pluggy-0.7.1:\n", + " Successfully uninstalled pluggy-0.7.1\n", + " Found existing installation: pytest 3.6.4\n", + " Uninstalling pytest-3.6.4:\n", + " Successfully uninstalled pytest-3.6.4\n", + " Found existing installation: torch 1.6.0+cu101\n", + " Uninstalling torch-1.6.0+cu101:\n", + " Successfully uninstalled torch-1.6.0+cu101\n", + " Found existing installation: torchvision 0.7.0+cu101\n", + " Uninstalling torchvision-0.7.0+cu101:\n", + " Successfully uninstalled torchvision-0.7.0+cu101\n", + " Found 
existing installation: html5lib 1.0.1\n", + " Uninstalling html5lib-1.0.1:\n", + " Successfully uninstalled html5lib-1.0.1\n", + " Found existing installation: bleach 3.1.5\n", + " Uninstalling bleach-3.1.5:\n", + " Successfully uninstalled bleach-3.1.5\n", + " Found existing installation: setuptools 49.6.0\n", + " Uninstalling setuptools-49.6.0:\n", + " Successfully uninstalled setuptools-49.6.0\n", + " Found existing installation: tensorboard 2.3.0\n", + " Uninstalling tensorboard-2.3.0:\n", + " Successfully uninstalled tensorboard-2.3.0\n", + " Found existing installation: PyYAML 3.13\n", + " Uninstalling PyYAML-3.13:\n", + " Successfully uninstalled PyYAML-3.13\n", + " Running setup.py develop for genrl\n", + "Successfully installed Pillow-7.1.0 appdirs-1.4.4 bleach-1.5.0 box2d-py-2.3.8 certifi-2019.11.28 cfgv-3.2.0 distlib-0.3.1 future-0.18.2 genrl gym-0.17.1 html5lib-0.9999999 identify-1.4.29 importlib-resources-1.0.1 matplotlib-3.2.1 nodeenv-1.5.0 numpy-1.18.2 opencv-python-4.2.0.34 pandas-1.0.4 pluggy-0.13.1 pre-commit-2.4.0 pytest-5.4.1 pyyaml-5.3.1 setuptools-41.0.0 six-1.14.0 tensorboard-1.15.0 tensorflow-tensorboard-1.5.1 torch-1.4.0 torchvision-0.5.0 virtualenv-20.0.31\n" + ], + "name": "stdout" + }, + { + "output_type": "display_data", + "data": { + "application/vnd.colab-display-data+json": { + "pip_warning": { + "packages": [ + "PIL", + "matplotlib", + "mpl_toolkits", + "numpy", + "pandas", + "pkg_resources", + "six" + ] + } + } + }, + "metadata": { + "tags": [] + } + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "kMa_HDBT3APh", + "colab_type": "code", + "colab": {} + }, + "source": [ + "import matplotlib.pyplot as plt\n", + "import torch" + ], + "execution_count": 1, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "IhfRPmuHPWUp", + "colab_type": "text" + }, + "source": [ + "## Epsilon Greedy Policy on a Bernoulli Bandit" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "rVSRR6bG3Dr3", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 744 + }, + "outputId": "6fe7d6c0-61bd-4560-c8ad-5c712bcc42ab" + }, + "source": [ + "from genrl.agents import EpsGreedyMABAgent, BernoulliMAB\n", + "from genrl.trainers import MABTrainer\n", + "\n", + "bandit = BernoulliMAB(arms=50, context_type=\"int\") \n", + "agent = EpsGreedyMABAgent(bandit, eps=0.1)\n", + "trainer = MABTrainer(agent, bandit) \n", + "results = trainer.train(2000)\n", + "\n", + "plt.plot(results[\"cumulative_regrets\"])" + ], + "execution_count": 2, + "outputs": [ + { + "output_type": "stream", + "text": [ + "\n", + "Started at 27-08-20 18:48:38\n", + "Training EpsGreedyMABAgent on BernoulliMAB for 2000 timesteps\n", + "timestep regret/regret reward/reward regret/cumulative_regret reward/cumulative_reward regret/regret_moving_avg reward/reward_moving_avg \n", + "100 0 1 14 86 0.14 0.86 \n", + "200 0 1 19 181 0.05 0.95 \n", + "300 0 1 29 271 0.1 0.9 \n", + "400 0 1 34 366 0.05 0.95 \n", + "500 0 1 44 456 0.1 0.9 \n", + "600 0 1 52 548 0.08 0.92 \n", + "700 0 1 55 645 0.03 0.97 \n", + "800 0 1 58 742 0.03 0.97 \n", + "900 0 1 62 838 0.04 0.96 \n", + "1000 0 1 67 933 0.05 0.95 \n", + "1100 0 1 69 1031 0.02 0.98 \n", + "1200 0 1 83 1117 0.14 0.86 \n", + "1300 0 1 90 1210 0.07 0.93 \n", + "1400 0 1 95 1305 0.05 0.95 \n", + "1500 0 1 100 1400 0.05 0.95 \n", + "1600 0 1 110 1490 0.1 0.9 \n", + "1700 0 1 114 1586 0.04 0.96 \n", + "1800 0 1 120 1680 0.06 0.94 \n", + "1900 0 1 123 1777 0.03 0.97 \n", + "2000 0 1 129 1871 0.06 0.94 \n", + 
"Training completed in 0 seconds\n", + "Final Regret Moving Average: 0.06 | Final Reward Moving Average: 0.94\n" + ], + "name": "stdout" + }, + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "[]" + ] + }, + "metadata": { + "tags": [] + }, + "execution_count": 2 + }, + { + "output_type": "display_data", + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD4CAYAAAAXUaZHAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAgAElEQVR4nO3deXxU9f3v8dcnIWEJgQCBENawqpHdiKK44oIWRev+81E3Wn56tdrF69L+rL23i1pbrXaxF2vrLq4t2krdN1SUVRbZQwKBJCQBQiBkne/9Yw4xQAIks5zJ5P18PPLIzPecyXlzZvLm5MyZc8w5h4iIxJcEvwOIiEj4qdxFROKQyl1EJA6p3EVE4pDKXUQkDnXwOwBAenq6y8rK8juGiEibsmjRolLnXO+mpsVEuWdlZbFw4UK/Y4iItClmlt/cNO2WERGJQyp3EZE4pHIXEYlDKncRkTikchcRiUMqdxGROKRyFxGJQyp3EREfbCjZzUPvrOXjtSUR+fkqdxERHzz+cS6PvreOz3PLIvLzY+ITqiIi7cHdry3j1UVbAKgNBBjdvzt3Tj06IstSuYuIRMHCvO288OVmRmZ0ZcoxGQCcMjw9YstTuYuIRMHtL38FwMXjB3DT6cMivjztcxcRibD7564mr6ySqyYOikqxg8pdRCSiCnZU8pePNgAwY3JW1JarchcRiaBfv7kKgJvPGMbwPqlRW672uYuIhOi9VcXcN3c1AecOmrZ1516O7pvKj88+KqqZVO4iIkdgQd52tu+paXLa7C83sXl7JWdnZxw0LTuzG9+e0J+EBIt0xP2o3EVEDqNgRyWX/eXzQ84zflAaf/yvCVFKdHgqdxGRZvzmP6v5bEMZlTV1APz64tGMHdi9yXkH9OgSzWiHddhyN7O/AdOAbc65Ud7Yg8AFQA2wAbjeObfTm3Y3MAOoB251zr0VoewiImFVH3As3byT2voAAM/Mz6d75ySG9u7KiD6pXDA2k9ROST6nPDJHsuX+JPBH4OlGY+8Adzvn6szsAeBu4E4zywauBI4F+gHvmtlI51x9eGOLiITfv5cXcusLS/Ybu23KCL57ylCfErXeYcvdOfexmWUdMPZ2o7vzgUu929OB2c65amCjma0HJgKH3lklIhIFLy3czL+XFTY7ffOOSgCemTGRxAQj0Yxxg9KiFS+swrHP/QbgRe92f4Jlv0+BN3YQM5sJzAQYNGhQGGKIiAQ559hQspuauv0PTXz841y2VVSTlZ7S5ONSOyVxRc5AThnROxoxIyqkcjeznwJ1wHMtfaxzbhYwCyAnJ+fgg0NFRFrpzeVF3Pz84ianfefEwfziolFRThR9rS53M7uO4ButU5xrOHJ/CzCw0WwDvDERkah56rM8AP589QQS7Jvjy83ghCE9fUoVXa0qdzObCtwBnOacq2w06XXgeTN7iOAbqiOAL0NOKSLSApu2V5KUaJw/OtPvKL45kkMhXwBOB9LNrAC4l+DRMR2Bdyz4v+J859yNzrmVZvYS8DXB3TU360gZEYmmQMBRuruamae2vSNcwulIjpa5qonhJw4x/6+AX4USSkSktT5aW0JdwNEntaPfUXyls0KKSNyorQ/wyuICACYM7uFzGn/p9AMiEhecc5z10Efkl1UyMasnYwa0zePTw0XlLiJt0iuLCthYurvhfl3AkV8WPDPjbVNG+JgsNqjcRaTNqaqt5/aXv8IMEhsd6tglOZEZk4cwqn/TJ/dqT1TuItLmlFRUA/DAJWO4PGfgYeZun1TuIhJTNm+vZO6KQpq4qFGDwvIqgHZ/RMyhqNxFJKb8+cP1vPDl5sPO1ykpgREZ0bsmaVujcheRmFFXH+CFLzeTndmNV26adMh5OyQkkNxBR3M3R+UuIr6rqq1n7opCSiuC1yg9OjOVLsmqp1Bo7YmI795btY0fvvhVw/3/mqjTgIdK5S4ivioqr2o4Pe+7PzqVnikd6ZmS7HOqtk/lLiK+WZS/nXdXbQNg6rF9Gd5Hb5CGi8pdRHyxs7KGy/7yOQEHHRKM+7492u9IcUXlLiK+ePyTXAIO7pmWzbQxmfTQrpiwUrmLSNQsLyhnV1UtAB+vLQXgonH96NVVH0YKN5W7iETFhpLdXPDHefuNXTy+v4o9QlTuIhJx/1hSwBPzNgLwq4tHMcJ74/SYTL2BGikqdxGJqPqA4743V7O3tp4pR/fh4vH99QGlKNBnd0Ukop6dn8+2imrOPiaDJ647XsUeJSp3EYmY3dV13Pv6SgD+Z1q2z2naF5W7iETM4vwdAEwa2kufOo0ylbuIRMw276Ia91+iDyhF22HL3cz+ZmbbzGxFo7GeZvaOma3zvvfwxs3MHjWz9Wa2zMwmRDK8iMSupZt38vwX+QD0Se3kc5r250i23J8Eph4wdhfwnnNuBPCedx/gPGCE9zUTeCw8MUUkVjnnKK+sPejrr5/k8lVBOaeO7E3n5ES/Y7Y7h33b2jn3sZllHTA8HTjdu/0U8CFwpzf+tHPOAfPNLM3MMp1zheEKLCKx5WdzVvLM/Pwmpx2f1YOnb5gY5UQCrT/OPaNRYRcBGd7t/kDj62MVeGMHlbuZzSS4dc+gQTp3s0isqw84/t/HGyjfW7vf+HurihnWO4WrTxh80GNOGNozWvHkACEfcOqcc2Z2iEvZNvu4WcAsgJycnBY/XkQizzW6SvXXW3fxm/+sISnRSDDbb76bJw7nhslDoh1PDqG15V68b3eLmWUC27zxLcDARvMN8MZEpI1xzjH195+wprhiv/FXbzqJMQPSfEolR6q15f46cC1wv/d9TqPxW8xsNnACUK797SJt0+7qOtYUV3DayN6MHxQs8+6dkxjVr7vPyeRIHLbczewFgm+epptZAXAvwVJ/ycxmAPnA5d7sbwLnA+uBSuD6CGQWkSjYd4z6xeP7c9H4/j6nkZY6kqNlrmpm0pQm5nXAzaGGEhH/VNbUMe0P8ygqrwKgTzedkrct0hl8RNq5vNI9zFtf2nB/W0U1uSV7OOuYPhyT2Y3jBvfwMZ20lspdpJ27b+4q3lpZvN9YgsEdU49mZIbOt95WqdxF2rH8sj28tbKYSUN78chV4xrGOyUl0q1Tko/JJFQqd5F2aMeeGuatL2XxpuBZG88b3Vfnf4kzKneRdujR99fx90/zAOiUlMAlEwb4G0jCTuUuEod2V
dUy8+mF7Npb1+T0gh2VDElP4fFrcujRJYmUjqqCeKNnVCQOLd20k/m528kZ3IO0LgdfJKNfWmfOyc5geJ+uPqSTaFC5i8ShTzcED2186PJxDOrVxec04geVu0gceO6LfN74amvD/fyyShITTMXejukyeyJx4IUvN7GmqIKAg4CDgT278P0zh/sdS3ykLXeROFC8q5pzsvvywKVj/I4iMUJb7iJtXF19gLLd1WToHDDSiLbcRdqgQMCRv70S5xzb99QQcNC7mz6EJN9QuYu0QX94fz0Pv7t2v7EBaZ19SiOxSOUu0oa8triAzzeUsSBvO71TO/I/3zoGCJ4L5pQR6T6nk1iichdpQx74z2oqqupI65zE9LH9mD5OF9GQpqncRdqI0t3VFO+qZsbkIdwzLdvvOBLjVO4iMW5NUQWvLNrccNm7Y/t18zmRtAUqd5EYVB9wBJwD4K+f5PLyogJSkhPJ6NaRk4dr37ocnspdJMbkl+3h3N9/TFVtoGFs7IDuzLllso+ppK1RuYvEmDlLt1JVG+C6k7JI7xo8o+OkYb18TiVtTUjlbmY/BL4LOGA5cD2QCcwGegGLgO8452pCzCnSbizI2w7Aj84ZqUvdSau1+vQDZtYfuBXIcc6NAhKBK4EHgIedc8OBHcCMcAQVaS9KKqqZcnQfFbuEJNRzy3QAOptZB6ALUAicCbziTX8KuCjEZYi0GwU7KlldVEFGd51KQELT6nJ3zm0BfgtsIljq5QR3w+x0zu27tlcB0OSnLMxsppktNLOFJSUlrY0hElcefW8dAMdk6nBHCU0ou2V6ANOBIUA/IAWYeqSPd87Ncs7lOOdyevfu3doYInFjeUE5Ly0sYFjvFL5z4mC/40gbF8obqmcBG51zJQBm9hpwMpBmZh28rfcBwJbQY4rEr/K9tXy0toSP1gT/gr3xtGE+J5J4EEq5bwJONLMuwF5gCrAQ+AC4lOARM9cCc0INKRLPnvw0r+EMj71Skrn0uAE+J5J40Opyd859YWavAIuBOmAJMAv4NzDbzH7pjT0RjqAi8WTFlnLuem0ZtXWO4ooq0rsm8/KNJ9EzJRkz8zuexIGQjnN3zt0L3HvAcC4wMZSfKxLPisqreO6LfFZs2cXZ2RkMSU9h0rBeDElP8TuaxBF9QlUkyn700lI+21BGz5RkZn3nOG2pS0ToGqoiUfTPJVv4bEMZpx/Vm7m3naJil4jRlrtIFOSW7GZHZQ1PfpYHwLWTssjQNU8lglTuIhFWuruasx76iEDwDL5cMLYfZxzdx99QEvdU7iIRVrBjLwEHt58zkjED0hjdv7vfkaQdULmLRFAg4Fi5tRyA00b2YfQAFbtEh95QFYmgZ+bn89N/rMAMMtO0j12iR+UuEkH/WraVzkmJzP7eiaR37eh3HGlHVO4iEVS+t5bOyYmcMFRXUpLoUrmLRNC2imrOH93X7xjSDukNVZEwKCqv4ol5udTWu/3Gd1bWkpGqfe0SfSp3kTB4ZdFmHv9kI6mdOtD4M6e9UpI5bnAP33JJ+6VyFwnRa4sL+O3ba+mUlMCye8/RKQUkJqjcRQ5QvreWZ+fnU1MXOKL5560vBeCxq3USMIkdKneRA7y1oogH31rTosecNrK3TikgMUXlLtJIRVUtd7y6DIA1v5xKxw6JPicSaR2Vuwiwq6qWfy7ZwpYdewE4f3RfFbu0aSp3EYLnWf/ZnJUAJCUat59zlM+JREKjcpd25YYnF7B8S/lB43uq6+iQYCz6n7PpmJRApyRttUvbpnKXuLe3pp63vy6iui7A+6u3MXZgGtmZ3Q6aL7tfN7p3SfIhoUj4qdwl7r3+1RbufHV5w/3/PnUo54/O9DGRSOSp3CWuLN28k7teXUZt/TfHqO+srMUMPvjx6XRKSqRvd50OQOJfSOVuZmnAX4FRgANuANYALwJZQB5wuXNuR0gpRZpRUlHNkk3fvLz+s7KI1UUVfGt0Jo3PA3BM31Sy0lN8SCjij1C33B8B/uOcu9TMkoEuwE+A95xz95vZXcBdwJ0hLkekST9/YyX/Xla431jfbp3409UTfEokEhtaXe5m1h04FbgOwDlXA9SY2XTgdG+2p4APUblLGLy4YBMvLti839ja4t0cN7gH/+fCYxvGtNtFJLQt9yFACfB3MxsLLAJuAzKcc/s2pYqAjKYebGYzgZkAgwYNCiGGtAcrt5bz1Gf5bC3fu98FpscPSuOK4wcyShedFtlPKOXeAZgAfN8594WZPUJwF0wD55wzM9fUg51zs4BZADk5OU3OIwLBYv/Wo/MAuOy4ATx42VifE4nEvlDKvQAocM594d1/hWC5F5tZpnOu0MwygW2hhpT2acmmHfzh/fWUVFQD8JtLxjBVVzUSOSKtLnfnXJGZbTazo5xza4ApwNfe17XA/d73OWFJKnGpuq6ejaV7mpz27PxNfLS2hFH9unHm0X24cFw/fXJU5AiFerTM94HnvCNlcoHrCV6X9SUzmwHkA5eHuAyJYz/750peXLi52ekj+nRlzi2To5hIJD6EVO7OuaVAThOTpoTycyU+/fH9dazbtnu/sc82lHFURio/OGtEk48Z2Tc1GtFE4o4+oSph5ZyjeFc1Abf/e+Q1dQF++/ZaeqYk063TNy+7lORErjh+IOfpdAAiYaVyl7Ca9XEu981d3ez0e6Ydw8XjB0QxkUj7pHKXFnvjq618uXF7k9Pm55bRKyWZO6YefD705A4JnDdKW+gi0aBylxaprqvnvjdXUbqnhq4dm375TB3VlyuO1wfTRPykcpcjtqmskrMe/oiaugC3nDGc28/V1YpEYpXKXQ6rrj7AE/M2snLrLmrqAsw8dSjXTBrsdywROQSVuxzWVwXl3Dd3NQkGvVM7cvMZw+neWVcsEollKnc5pA0lu7nksc8A+Petp3BME5enE5HYo3KXJr29soj8skpWF1UAcOXxAzkqQx8oEmkrVO5ykKraem58dhEB73NIaV2SuGdaNgkJdugHikjMULkLLy3czANzV7PvM6UB5wg4+MVFo7h4fH+SExNI7pDga0YRaRmVuzD7y03U1geYPq5/w1jHDgl8a3Rms8eyi0hs02+usLF0Dz1TkvnFRaP8jiIiYaK/tdu5tcUV7Kis5dxRugiGSDxRubdzsz7OBWDy8HSfk4hIOKnc27GSimpeWVQAwIlDe/mcRkTCSeXeTlXW1PHtxz4F4L5vjyYpUS8FkXiiN1TbkVWFu9ixpwaADaV72Lx9LxndOjL1WO1vF4k3Kvd2YtuuKs5/9BMOuEASz333BHqkJPsTSkQiRuXeTnyeW4Zz8JPzj2bsgDQAUjslMbyPTikgEo9U7u3EovwdAEwf15+Mbp18TiMikRbyu2hmlmhmS8zsX979IWb2hZmtN7MXzUx/8/usqraepz/PJynRVOwi7UQ4DpG4DVjV6P4DwMPOueHADmBGGJYhrbR40w6+/8ISAK4/eYjPaUQkWkIqdzMbAHwL+Kt334AzgVe8WZ4CLgplGdJ6e6rreOqzPD5YvY1j+3XjyuMH+h1JRKIk1H3uvwfuAPa9K9cL
2Omcq/PuFwD9m3qgRN60P8xjY+kexgzozuu3TPY7johEUau33M1sGrDNObeolY+faWYLzWxhSUlJa2NIM1ZsKWdj6R7OG9WX31021u84IhJloeyWORm40MzygNkEd8c8AqSZ2b6/CAYAW5p6sHNulnMuxzmX07t37xBiSFOe/3ITANdMymKErqAk0u60utydc3c75wY457KAK4H3nXNXAx8Al3qzXQvMCTmltIhzjue/2ERWry5MGqZzxoi0R5E4zv1OYLaZ/RJYAjwRgWXIATZvr2T2gk3UB6C2PgBATlZPn1OJiF/CUu7OuQ+BD73bucDEcPxcOXKzF2ziTx9saLgcXmrHDlw1UUfHiLRX+oRqG/S/nlvEm8uLDhrvn9aZT+8604dEIhJrVO5tjHOON5cXMWZAd844qs9+08YPSvMplYjEGpV7G5NXVglAzuCe/PDskT6nEZFYpXKPcY+8u44n5uU23K8PBM/Ze1Z2n+YeIiKico9lC/O2M2fpFlI7JXF2dkbDeNeOHThucA8fk4lIrFO5x7CZzyxi+54arjspi59feKzfcUSkDVG5x5j3Vxfzf9/4mnrn2L6nhptOH8b/Pucov2OJSBujco8xLy0oYOvOKqaNyWTSUOOKnIEkJJjfsUSkjVG5x5jVRbvo0jGRh64Y53cUEWnDVO4++mhtCb99aw2BRletLtixlwvG9vMxlYjEA5W7D3JLdlOwYy/PfJ7PmuIKTh2R3jCtX1pnrj5hkI/pRCQeqNyjzDnHRX/6lF1VweuZjB2Yxl+vPd7nVCISb1TuUbRk0w5++/YadlXVcf3JWUwbk8ngXil+xxKROKRyj4K6+gCriyp4dv4mPt9QxknDenHVxEGM1EU0RCRCVO5R8PyXm/jZnJUADE1P4fnvnehzIhGJdyr3KPjXskI6JSXw56snMKx3V7/jiEg7oHKPoOq6egp3VrF9Tw2dkxI58+iMwz9IRCQMVO4RdOsLS3hrZTEA10wa7HMaEWlPVO4R8sZXW3lrZTHjB6Vx7aQsJjc6ll1EJNJU7mGyq6qWuvpvPmn64FtrALjy+IFcNL6/X7FEpJ1SuYfBB2u2cf3fFxw0/r1ThnDF8fq0qYhEn8o9DH7/zloA7pmWTQfvDI4JBuce29fPWCLSjrW63M1sIPA0kAE4YJZz7hEz6wm8CGQBecDlzrkdoUeNTVW19XxVUA7AjMlDfE4jIhKUEMJj64AfO+eygROBm80sG7gLeM85NwJ4z7sfl6rr6jnuF+8A8JtLx/icRkTkG60ud+dcoXNusXe7AlgF9AemA095sz0FXBRqyFj15Kd57KmpJzuzG1NHaReMiMSOULbcG5hZFjAe+ALIcM4VepOKCO62aeoxM81soZktLCkpCUeMqKoPOO6buxqABy8bQ7dOST4nEhH5RshvqJpZV+BV4AfOuV1m31wSzjnnzMw19Tjn3CxgFkBOTk6T88SizzaUctOzi6mtDwDw8wuyObZfd59TiYjsL6RyN7MkgsX+nHPuNW+42MwynXOFZpYJbAs1ZCzILdnNp+tL+WRdKeV7a5kxeQjJHRKYpqsmiUgMCuVoGQOeAFY55x5qNOl14Frgfu/7nJASxohfv7mKd1cF/58akp7CPdOyfU4kItK8ULbcTwa+Ayw3s6Xe2E8IlvpLZjYDyAcuDy2iv1ZsKeeW5xezdWcVp4xI5+ErxpHaSR8PEJHY1uqWcs7NA6yZyVNa+3NjRfGuKubnlvHJulLyyiq5eHx/LssZQHrXjn5HExE5LG2CNuO+N1fxz6VbAeiZkszvLhtLQkJz/5eJiMQWlXsz3lhWyLiBaTx0+Vh6pXRUsYtIm9Iuy72wfC+riyqane6coz7g6JPakaG6cpKItEHtstxvfWEJC/IOf7qb80dnRiGNiEj4tbty/3R9KQvydnDWMX24+Yzhzc6XlJhAdma3KCYTEQmfuC/3jaV72L6npuH+05/nAXDNpCzGD+rhTygRkQiL63Iv3V3NlN99SOCAkxuMG5jGqSN7+xNKRCQK4rbcV2wp5/65qwk4+PHZIxkzMK1h2sgMvUkqIvEtbsv9X8sK+XRDKcdn9eDKiYPonaoPH4lI+xG35b6toop+3Tvz8o0n+R1FRCTq4qrcnXM89M5atu6s4vMNZWR06+R3JBERX8RVuReWV/GH99fTo0sSXZI7cHZ2k9cJERGJe3FV7s/MzwfgoSvGccZRfXxOIyLin7BcZi8WbN25l8c+3ADA5OHpPqcREfFX3JT7P5ZsAeDi8f1JSoybf5aISKvETQv+a1khnZMSefiKcX5HERHxXdyUeyDg6Jemo2NERCCOyr24ooqThmlfu4gIxEm5L8rfzs7KWl0CT0TEExflvqygHICzsnX4o4gIxEm5b6uoJinRdP51ERFPXJT73OWFpHftiJmucyoiAhEsdzObamZrzGy9md0VqeWs2FJOXlkl3TsnRWoRIiJtTkTK3cwSgT8B5wHZwFVmlh2JZf33M4sAuHBcv0j8eBGRNilSW+4TgfXOuVznXA0wG5ge7oV8uGYbW3buZfq4ftx02rBw/3gRkTYrUuXeH9jc6H6BN9bAzGaa2UIzW1hSUtKqhaR2SmLamExuOn2Y9reLiDTi21khnXOzgFkAOTk57jCzN+m4wT04brAuci0icqBIbblvAQY2uj/AGxMRkSiIVLkvAEaY2RAzSwauBF6P0LJEROQAEdkt45yrM7NbgLeAROBvzrmVkViWiIgcLGL73J1zbwJvRurni4hI8+LiE6oiIrI/lbuISBxSuYuIxCGVu4hIHDLnWvX5ofCGMCsB8lv58HSgNIxxwiVWc0HsZlOullGulonHXIOdc72bmhAT5R4KM1vonMvxO8eBYjUXxG425WoZ5WqZ9pZLu2VEROKQyl1EJA7FQ7nP8jtAM2I1F8RuNuVqGeVqmXaVq83vcxcRkYPFw5a7iIgcQOUuIhKH2nS5R+si3M0se6CZfWBmX5vZSjO7zRv/uZltMbOl3tf5jR5zt5d1jZmdG8FseWa23Fv+Qm+sp5m9Y2brvO89vHEzs0e9XMvMbEKEMh3VaJ0sNbNdZvYDP9aXmf3NzLaZ2YpGYy1eP2Z2rTf/OjO7NkK5HjSz1d6y/2Fmad54lpntbbTe/tLoMcd5z/96L3tIlylrJleLn7dw/742k+vFRpnyzGypNx7N9dVcN0T3Neaca5NfBE8lvAEYCiQDXwHZUVx+JjDBu50KrCV4MfCfA7c3MX+2l7EjMMTLnhihbHlA+gFjvwHu8m7fBTzg3T4fmAsYcCLwRZSeuyJgsB/rCzgVmACsaO36AXoCud73Ht7tHhHIdQ7Qwbv9QKNcWY3nO+DnfOllNS/7eRHI1aLnLRK/r03lOmD674Cf+bC+muuGqL7G2vKWe1Quwt0c51yhc26xd7sCWMUB14k9wHRgtnOu2jm3EVhP8N8QLdOBp7zbTwEXNRp/2gXNB9LMLDPCWaYAG5xzh/pUcsTWl3PuY2B7E8tryfo5F3jHObfdObcDeAeYGu5czrm3nXN13t35BK9q1iwvWzf
n3HwXbIinG/1bwpbrEJp73sL++3qoXN7W9+XAC4f6GRFaX811Q1RfY2253A97Ee5oMbMsYDzwhTd0i/fn1d/2/elFdPM64G0zW2RmM72xDOdcoXe7CMjwIdc+V7L/L53f6wtavn78WG83ENzC22eImS0xs4/M7BRvrL+XJRq5WvK8RXt9nQIUO+fWNRqL+vo6oBui+hpry+UeE8ysK/Aq8APn3C7gMWAYMA4oJPinYbRNds5NAM4DbjazUxtP9LZQfDkG1oKXXbwQeNkbioX1tR8/109zzOynQB3wnDdUCAxyzo0HfgQ8b2bdohgp5p63A1zF/hsQUV9fTXRDg2i8xtpyuft+EW4zSyL45D3nnHsNwDlX7Jyrd84FgMf5ZldC1PI657Z437cB//AyFO/b3eJ93xbtXJ7zgMXOuWIvo+/ry9PS9RO1fGZ2HTANuNorBbzdHmXe7UUE92eP9DI03nUTkVyteN6iub46AN8GXmyUN6rrq6luIMqvsbZc7r5ehNvbp/cEsMo591Cj8cb7qy8G9r2T/zpwpZl1NLMhwAiCb+SEO1eKmaXuu03wDbkV3vL3vdt+LTCnUa5rvHfsTwTKG/3pGAn7bVH5vb4aaen6eQs4x8x6eLskzvHGwsrMpgJ3ABc65yobjfc2s0Tv9lCC6yfXy7bLzE70XqPXNPq3hDNXS5+3aP6+ngWsds417G6J5vpqrhuI9msslHeF/f4i+C7zWoL/C/80ysueTPDPqmXAUu/rfOAZYLk3/jqQ2egxP/WyriHEd+QPkWsowSMRvgJW7lsvQC/gPYlMYagAAACkSURBVGAd8C7Q0xs34E9eruVATgTXWQpQBnRvNBb19UXwP5dCoJbgfswZrVk/BPeBr/e+ro9QrvUE97vue439xZv3Eu/5XQosBi5o9HNyCJbtBuCPeJ9ED3OuFj9v4f59bSqXN/4kcOMB80ZzfTXXDVF9jen0AyIicagt75YREZFmqNxFROKQyl1EJA6p3EVE4pDKXUQkDqncRUTikMpdRCQO/X9qyp/O7B+JfwAAAABJRU5ErkJggg==\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "tags": [], + "needs_background": "light" + } + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ySu7WO95QpCA", + "colab_type": "text" + }, + "source": [ + "## Linear Posterior Policy with the Covertype Dataset based Bandit" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "-sZC2jIJ3Ihm", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 1000 + }, + "outputId": "6d23fc6c-9965-4e3e-aec5-258d4f5bfe4a" + }, + "source": [ + "from genrl.agents import LinearPosteriorAgent\n", + "from genrl.trainers import DCBTrainer\n", + "from genrl.utils.data_bandits import CovertypeDataBandit\n", + "\n", + "bandit = CovertypeDataBandit(download=True)\n", + "agent = LinearPosteriorAgent(bandit)\n", + "trainer = DCBTrainer(agent, bandit)\n", + "results = trainer.train(10000)\n", + "\n", + "plt.plot(results[\"cumulative_regrets\"])" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "text": [ + "Downloading https://archive.ics.uci.edu/ml/machine-learning-databases/covtype/covtype.data.gz to /content/data/Covertype/covtype.data.gz\n", + "\n", + "Started at 25-08-20 18:24:09\n", + "Training LinearPosteriorAgent on CovertypeDataBandit for 10000 timesteps\n", + "timestep regret/regret reward/reward regret/cumulative_regret reward/cumulative_reward regret/regret_moving_avg reward/reward_moving_avg \n", + "100 0 1 86 14 0.86 0.14 \n", + "200 1 0 168 32 0.84 0.16 \n", + "300 1 0 252 48 0.84 0.16 \n", + "400 1 0 339 61 0.8475 0.1525 \n", + "500 1 0 424 76 0.848 0.152 \n" + ], + "name": "stdout" + }, + { + "output_type": "stream", + "text": [ + "/content/genrl/genrl/agents/bandits/contextual/linpos.py:92: RuntimeWarning: covariance is not positive-semidefinite.\n", + " for i in range(self.n_actions)\n" + ], + "name": "stderr" + }, + { + "output_type": "stream", + "text": [ + "600 1 0 509 91 0.846 0.154 \n", + "700 1 0 557 143 0.778 0.222 \n", + "800 1 0 588 212 0.672 0.328 \n", + "900 0 1 634 266 0.59 0.41 \n", + "1000 0 1 678 322 0.508 0.492 \n", + "1100 0 1 715 385 0.412 0.588 \n", + "1200 0 1 752 448 0.39 0.61 \n", + "1300 0 1 798 502 0.42 0.58 \n", + "1400 0 1 827 573 0.386 0.614 \n", + "1500 1 0 856 644 0.356 0.644 \n", + "1600 1 0 904 696 0.378 0.622 \n", + "1700 0 1 942 758 0.38 0.62 \n", + "1800 1 0 985 815 0.374 0.626 \n", + "1900 1 0 1023 877 0.392 0.608 \n", + "2000 0 1 1053 947 0.394 0.606 \n", + "2100 1 0 1101 999 0.394 0.606 \n", + "2200 0 1 1127 1073 0.37 0.63 \n", + "2300 1 0 1176 1124 0.382 0.618 \n", + "2400 1 0 1223 1177 0.4 0.6 \n", + "2500 0 1 1262 1238 0.418 0.582 \n", + "2600 1 0 1307 1293 0.412 0.588 \n", + "2700 0 1 1348 1352 0.442 0.558 \n", + "2800 1 0 1392 1408 0.432 0.568 \n", + "2900 1 0 1440 1460 0.434 0.566 \n", + "3000 1 0 1477 1523 0.43 0.57 \n", + "3100 1 0 1526 1574 0.438 0.562 \n", + "3200 1 0 1561 1639 0.426 0.574 \n", + "3300 0 1 1597 1703 0.41 0.59 \n", + "3400 0 1 1638 1762 0.396 0.604 \n", + "3500 0 1 1678 1822 0.402 0.598 \n", + "3600 0 1 1712 1888 0.372 0.628 \n", + "3700 0 1 1757 1943 0.392 0.608 \n", + "3800 0 1 1805 1995 0.416 0.584 \n", + "3900 1 0 1841 2059 0.406 0.594 \n", + "4000 0 1 1873 2127 0.39 0.61 \n", + "4100 1 0 1907 2193 0.39 0.61 \n", + "4200 0 1 1950 2250 0.386 0.614 \n", + "4300 0 1 1986 2314 0.362 0.638 \n", + "4400 0 1 2026 2374 0.37 0.63 \n", + "4500 0 1 2067 2433 0.388 0.612 \n", + "4600 1 0 2107 2493 0.4 0.6 \n", + "4700 0 1 2142 2558 0.384 0.616 \n", + "4800 0 1 2186 2614 0.4 0.6 \n", + "4900 0 1 2228 2672 0.404 0.596 \n", + 
"5000 1 0 2268 2732 0.402 0.598 \n", + "5100 0 1 2301 2799 0.388 0.612 \n", + "5200 0 1 2335 2865 0.386 0.614 \n", + "5300 0 1 2375 2925 0.378 0.622 \n", + "5400 0 1 2410 2990 0.364 0.636 \n", + "5500 1 0 2456 3044 0.376 0.624 \n", + "5600 1 0 2499 3101 0.396 0.604 \n", + "5700 1 0 2536 3164 0.402 0.598 \n", + "5800 0 1 2576 3224 0.402 0.598 \n", + "5900 0 1 2610 3290 0.4 0.6 \n", + "6000 0 1 2654 3346 0.396 0.604 \n", + "6100 1 0 2701 3399 0.404 0.596 \n", + "6200 1 0 2745 3455 0.418 0.582 \n", + "6300 1 0 2789 3511 0.426 0.574 \n", + "6400 1 0 2824 3576 0.428 0.572 \n", + "6500 0 1 2867 3633 0.426 0.574 \n", + "6600 1 0 2909 3691 0.416 0.584 \n", + "6700 0 1 2952 3748 0.414 0.586 \n", + "6800 0 1 2997 3803 0.416 0.584 \n", + "6900 1 0 3040 3860 0.432 0.568 \n", + "7000 1 0 3076 3924 0.418 0.582 \n", + "7100 0 1 3118 3982 0.418 0.582 \n", + "7200 0 1 3158 4042 0.412 0.588 \n", + "7300 0 1 3205 4095 0.416 0.584 \n", + "7400 0 1 3241 4159 0.402 0.598 \n", + "7500 0 1 3280 4220 0.408 0.592 \n", + "7600 0 1 3320 4280 0.404 0.596 \n", + "7700 1 0 3363 4337 0.41 0.59 \n", + "7800 0 1 3403 4397 0.396 0.604 \n", + "7900 1 0 3442 4458 0.402 0.598 \n", + "8000 1 0 3477 4523 0.394 0.606 \n", + "8100 0 1 3523 4577 0.406 0.594 \n", + "8200 1 0 3568 4632 0.41 0.59 \n", + "8300 1 0 3607 4693 0.408 0.592 \n", + "8400 0 1 3658 4742 0.432 0.568 \n", + "8500 0 1 3699 4801 0.444 0.556 \n", + "8600 0 1 3745 4855 0.444 0.556 \n", + "8700 0 1 3790 4910 0.444 0.556 \n", + "8800 1 0 3828 4972 0.442 0.558 \n", + "8900 0 1 3859 5041 0.402 0.598 \n", + "9000 0 1 3894 5106 0.39 0.61 \n", + "9100 0 1 3939 5161 0.388 0.612 \n", + "9200 1 0 3979 5221 0.378 0.622 \n", + "9300 0 1 4015 5285 0.374 0.626 \n", + "9400 0 1 4058 5342 0.398 0.602 \n", + "9500 0 1 4092 5408 0.396 0.604 \n", + "9600 0 1 4133 5467 0.388 0.612 \n", + "9700 0 1 4183 5517 0.408 0.592 \n", + "9800 1 0 4225 5575 0.42 0.58 \n", + "9900 0 1 4256 5644 0.396 0.604 \n", + "10000 1 0 4298 5702 0.412 0.588 \n", + "Training completed in 93 seconds\n", + "Final Regret Moving Average: 0.412 | Final Reward Moving Average: 0.588\n" + ], + "name": "stdout" + }, + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "[]" + ] + }, + "metadata": { + "tags": [] + }, + "execution_count": 50 + }, + { + "output_type": "display_data", + "data": { + "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAX0AAAD4CAYAAAAAczaOAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAgAElEQVR4nO3dd3wVVd7H8c8PQu810gMERHqJgF3AtaKIomJZsS27j1h3bazr2lYfFV37iqzuCjYEGyiuiIAVERMILRAIISGFDklMQvp5/sjgE1mRADeZm3u/79frvjL3zOTe38mEb4Zzz8yYcw4REQkPtfwuQEREqo9CX0QkjCj0RUTCiEJfRCSMKPRFRMJIhN8F/JrWrVu7qKgov8sQEalR4uLidjnn2vzSuqAO/aioKGJjY/0uQ0SkRjGz1IOt0/COiEgYUeiLiIQRhb6ISBhR6IuIhBGFvohIGFHoi4iEEYW+iEgYUeiLiASR0jLHOz9s4ZPVW6vk9YP65CwRkXASn5bFI/MS+CFlL6P7t+Pcfu0C/h4KfRERn+UWlvDi4iSmfrmJRnUjmDKuP+OGdKyS91Loi4j4ZFt2Aa8tSeGVr5MpKXOMG9KR+87rTbOGdarsPRX6IiLVrKzM8dayLTz40VqKSx3n9juGq0+IYni3VlX+3gp9EZFqtCo9iztmr2TD9lxO7N6KR8f2I6p1o2p7f4W+iEg1KCwp5dF565ixNJXIJvV56pIBXDCwPXVqV+8kSoW+iEgV25FTwMTX44hPy+K3w7twx5nHVum4/a9R6IuIVKEvN+zkjtkr+bGgmBevGMx5/QM/DfNwKPRFRKrAnrwipsxP5O1lW+jWuhGvXz+UXsc09bsshb6ISCBlZO1jyqfrmbsykzIH153UlbvOPpb6dWr7XRqg0BcRCYi9eUU8u3Ajb36finNw+dDOXD60M307NPO7tJ+pdOibWW0gFshwzo02s67ATKAVEAf81jlXZGb1gBnAEGA3cJlzLsV7jcnA9UApcItzbn4gOyMi4of3l6fzlw/XkF9UykWDO3D7GT3p1LKh32X9osM50r8VWAfsH5R6HHjaOTfTzKZSHuYveV/3OueizWy8t91lZtYbGA/0AdoDn5tZT+dcaYD6IiJSrbLyi7hvzlo+WplJTJcWPHBBn6A7sj9QpSaImllH4DzgFe+5ASOBd71NpgMXestjvOd460d5248BZjrnCp1zm4EkYGggOiEiUp2y84t56KMEhj26kHmrMrllZDRv3DAs6AMfKn+k/wxwF9DEe94KyHLOlXjP04EO3nIHIA3AOVdiZtne9h2ApRVes+L3/MTMJgITATp37lzpjoiIVId5q7Zy/9y17M4rZMyA9vzh9O5BMSunsg4Z+mY2GtjhnIszs9OruiDn3DRgGkBMTIyr6vcTEamM1N15TP1yE28vS6NnZGNeu/b4GnFkf6DKHOmfBFxgZucC9Skf038WaG5mEd7Rfkcgw9s+A+gEpJtZBNCM8g9097fvV/F7RESC0r6iUh7/dP1Ps3KuOTGK+0b3pnYt87u0I3LIMX3n3GTnXEfnXBTlH8Qucs5dCSwGxnmbTQDmeMtzved46xc555zXPt7M6nkzf3oAywLWExGRAFu3NYdLX/6O15akMHZQB76+ewQPXNCnxgY+HN08/buBmWb2N2AF8KrX/irwupklAXso/0OBc26tmc0CEoASYJJm7ohIMNq8K48XFiXx0apMGteLYOpVQzi77zF+lxUQVn4QHpxiYmJcbGys32WISJjYlVvIB8szePKzRGqZMbp/O+4+pxetG9fzu7TDYmZxzrmYX1qnM3JFJOxlZu3jxcVJfLAig/yiUk7p0ZonxvWnXbMGfpcWcAp9EQlbW7P38e9vU5i+JAUHjOrVlhtPj6Zfx5o3K6eyFPoiEnZKyxxPfpbIq19vpqi0jLGDOnDzyGi6tWnsd2lVTqEvImFlx48F3PTmCpal7GHsoA7cNDKa7mEQ9vsp9EUkLDjnePP7LTzx6XqKSst4dGw/rhgWfmf9K/RFJOQl78zlLx+uYcmm3ZzQrRX3nndcjTybNhAU+iIS0uLTsvifN+LIKyzhwQv68NvhXahVg0+uOloKfREJSSm78vjr3LV8tWEnkU3rMXPiCfRuX3MujFZVFPoiElJKyxwPfbSWGUtTaVinNreO6sF1J3WlWcM6fpcWFBT6IhIysvOLuX76D8Sm7uXyoZ2ZNKI7HVsE5x2s/KLQF5GQsG5rDv/zRhyZWQU8eckAxg3p6HdJQUmhLyI1WtqefF79ZjOvL02lRcO6vHHDMIZ2bel3WUFLoS8iNZJzjsWJO7htZjx5RaWc378d957XmzZNatbF0aqbQl9EapzduYU88FECH63MJLptY/55dQxdWzfyu6waQaEvIjWGc47Zsek8PC+B3MISbhoRzaQR0TSoW9vv0moMhb6I1AjxaVk8Mi+BH1L2Mrhzcx6/uD89Ipv4XVaNo9AXkaDmnOO5hUk8/fkGmtaP4P7ze3P1CVE1+paFflLoi0jQys4vZspn63lj6RbO69+ORy7sS/OGdf0uq0ZT6ItI0NmeU8DTCzbwblw6JWWO353SlT+fexxmOro/Wgp9EQkaJaVlvLYkhZe+2EROQTEXD+7IhYM6cEL3Vn6XFjIU+iISFJZt3sNj/1nH8i1ZDO7cnEcv6kevY3SBtEBT6IuIr+JS9zJl/nqWJu+hVaO6PHPZQC4c1MHvskKWQl9EfPP8wo08tWADzRrU4f7ze3PZ8Z1oWFexVJX00xWRapdTUMy9H6zho5WZnNUnkimXDKBpfV36uDoo9EWk2pSVOV5bksLzizaSta+YW0f14MYR3akXoTNqq4tCX0SqRUbWPh6dt455q7cyNKold519LDFRuhpmdVPoi0iV2ptXxKOfrOODFRmYwR1n9mTSiGjNufeJQl9EqoRzjveXZ/DoJ+vYk1/ExYM7ctOIaKJ0NUxfKfRFJODiUvcw+f3VbNiey4BOzZl+3VD6dmjmd1mCQl9EAsQ5x+qMbF7+Mpl5q7fSpkk9Hhnbl/HHd9bF0YKIQl9EjlpJaRl/+XANM39Io25ELW4aEc3vT+tGE03DDDoKfRE5Ksk7c/nT7JWs2JLF9Sd35cbTu9OqsW5ZGKwU+iJyRHILS3hmwQZeW5JCgzq1eXb8QC4Y0F6zcoKcQl9EDtuna7bxwNy1bMspYOygDtxzTi8im9b3uyypBIW+iFRaXOoeZnyXypz4TI6NbMLzVwzieJ1gVaMo9EXkkIpLy5j6xSaeWbgRA353SlfuPKsXdSNq+V2aHCaFvogcVEFxKdO+Sua95emk7s7n7D7H8NjF/XTLwhrskH+mzay+mS0zs5VmttbMHvTau5rZ92aWZGbvmFldr72e9zzJWx9V4bUme+2JZnZWVXVKRI7elt35jJu6hL8v2ECrRnV54YpB/OPKwQr8Gq4yR/qFwEjnXK6Z1QG+MbP/AH8EnnbOzTSzqcD1wEve173OuWgzGw88DlxmZr2B8UAfoD3wuZn1dM6VVk
G/ROQIlZU5/r5gA9O+SqZeRC2mXjWEs/se43dZEiCHPNJ35XK9p3W8hwNGAu967dOBC73lMd5zvPWjrHwO1xhgpnOu0Dm3GUgChgakFyISECvTsjj/hW94YXESv+kTybxbTlHgh5hKjembWW0gDogGXgQ2AVnOuRJvk3Rg//3NOgBpAM65EjPLBlp57UsrvGzF76n4XhOBiQCdO3c+zO6IyJEoKS1j6pebePrzjbRuXJcnxvXnkiEdNec+BFUq9L0hmIFm1hz4AOhVVQU556YB0wBiYmJcVb2PiJTLzNrHjW8uJz4ti3P7HcOjY/VBbSg7rNk7zrksM1sMnAA0N7MI72i/I5DhbZYBdALSzSwCaAbsrtC+X8XvEZFq5pzjrWVbePbzjeQVluiM2jBRmdk7bbwjfMysAfAbYB2wGBjnbTYBmOMtz/We461f5JxzXvt4b3ZPV6AHsCxQHRGRytu0M5eLXlrCvR+soXXjerzz+xMYM7CDAj8MVOZIvx0w3RvXrwXMcs59bGYJwEwz+xuwAnjV2/5V4HUzSwL2UD5jB+fcWjObBSQAJcAkzdwRqV7FpWXMXLaFJ+YnUruW8b8X9WP88Z0U9mHEyg/Cg1NMTIyLjY31uwyRkLBs8x7ueW8VybvyGNKlBc9cNpBOLRv6XZZUATOLc87F/NI6nZErEuJ+LCjmw/hMHpy7lsim9fnn1TGccVxbHd2HKYW+SIgqKS3jgxXl96jdm1/M8VEtmHrVEF3rPswp9EVCjHOOWbFpPL1gI9tyCujdrikvXjmY4V1bUUu3LQx7Cn2REJK2J58/zornh5S99GnflAfH9OHM3pEaypGfKPRFQoBzjn9/m8JTnyVSXOp4dGw/Lju+k25ILv9FoS9SwyVu+5E7Zq9kdUY2J0W34oHz+9AjsonfZUmQUuiL1FBlZY43v0/lb/PW0bBubR6/uB/jhujoXn6dQl+kBtq8K4+7313FspQ9nBTdiicvGUC7Zg38LktqAIW+SA3inOODFRk8+FECzjkeGduXK4Z21ge1UmkKfZEaIiEzh+cXbeQ/a7YxqHNznr1sEJ1b6YxaOTwKfZEg55zjjaWp3D93LfUianPbGT24eWQPjd3LEVHoiwSx5Vv28tIXm1iQsJ1TerTmufGDaNFI17qXI6fQFwlCKbvymPrlJt6JTaNu7Vrcdfax/OHU7jqjVo6aQl8kiJSUljFlfiL//jYFh+Pq4V2446xjaVK/jt+lSYhQ6IsEgbIyx2cJ23hifiLJO/O4cGB77j6nl6ZhSsAp9EV8lro7j8nvr2bJpt1EtWrIP64czLn92vldloQohb6IT0rLHE9+lsir32ymthkPjenDFUM7E1H7kHcxFTliCn0RH6xOz+avc9ewYksWFw5szx1nHUvHFppzL1VPoS9SjfIKS3jwo7XMik2nSf0IHh3bj8uH6h61Un0U+iLVZHV6Nne+u5LE7T9yw8lduXFENC01516qmUJfpIrlFBQz5dNE3lq2hRYN6/DK1TGMOi7S77IkTCn0RapITkExc+IzeX7hRnbmFnLVsC7cPDKatk3r+12ahDGFvkiA5RWW8PJXybz6dTJ5RaX0jGzMy78dwqDOLfwuTUShLxJICZk53P5OPInbf+TM3pFcc1KUbkguQUWhLxIAO3IKeHFxEq8vTaVxvQimXzeU03q28bsskf+i0Bc5CvlFJTwybx2zY9MpdY4xAztw3+jempUjQUuhL3IEtuzO55VvkpkTn0lOQTGXxXTiD6d1J6p1I79LE/lVCn2Rw7Ajp4DnFyXxzg9pFJeVcWbvSK4c1oVTNZQjNYRCX6QSnHO8tWwLD85NoKSsjEtjOnHTyGhdOkFqHIW+yCFszyngtpnxfJe8mxO7t+KBC/rQM7KJ32WJHBGFvshBlJU5nlm4kZe/3IQD7j+/N1cN70IdXQVTajCFvsgvWJy4g+cXbmT5lizO7B3JXWf3IrptY7/LEjlqCn2RCvbmFfHSl5uY9lUybZvU4+ExfbhqeBddBVNChkJfBCguLeOFRUlM+yqZfcWljBvSkUfG9qVeRG2/SxMJKIW+hL2VaVn8+YPVrM3M4aw+kdw8sgd92jfV0b2EJIW+hK2s/CKe+XwjM75LoWWjejx1yQAuHtLR77JEqpRCX8KOc47Zsen8bV4CPxaWcOmQTtw7+jia1q/jd2kiVe6Qc8/MrJOZLTazBDNba2a3eu0tzWyBmW30vrbw2s3MnjOzJDNbZWaDK7zWBG/7jWY2oeq6JfLLlm/Zy/hpS7nrvVUce0wTPr75ZB4f11+BL2GjMkf6JcCfnHPLzawJEGdmC4BrgIXOucfM7B7gHuBu4Bygh/cYBrwEDDOzlsD9QAzgvNeZ65zbG+hOiRxod24hf5y1ki837KR5wzrcN7o3154YpUseS9g5ZOg757YCW73lH81sHdABGAOc7m02HfiC8tAfA8xwzjlgqZk1N7N23rYLnHN7ALw/HGcDbwewPyI/45xjTnwmU+Ynsiu3kNvP6Mk1J0bRrKGO7CU8HdaYvplFAYOA74FI7w8CwDZg/00/OwBpFb4t3Ws7WPuB7zERmAjQuXPnwylP5GcSMnO4b84a4lL30r1NI96eOJzBunuVhLlKh76ZNQbeA25zzuVUnM7mnHNm5gJRkHNuGjANICYmJiCvKeHDOcfXG3fxr28380XiTprUj+Dxi/txaUwnTcEUoZKhb2Z1KA/8N51z73vN282snXNuqzd8s8NrzwA6Vfj2jl5bBv8/HLS//YsjL13k53bnFnLH7JUsTtxJk3oRTBrRnetP7qYbmohUcMjQt/LDo1eBdc65v1dYNReYADzmfZ1Tof0mM5tJ+Qe52d4fhvnAo/tn+QBnApMD0w0Jd0k7fuSG6bFszS5g8jm9uPqEKBrU1dm0IgeqzJH+ScBvgdVmFu+1/ZnysJ9lZtcDqcCl3rpPgHOBJCAfuBbAObfHzB4GfvC2e2j/h7oiR2rzrjz+95N1LFi3nVaN6jL9uqEM79bK77JEgpaVT7IJTjExMS42NtbvMiQIrcnIZtpXycxdmUmDOrW5NKYj153clS6tdLtCETOLc87F/NI6nZErNcqyzXt45etkPkvYToM6tbnupK78/rRuRDat73dpIjWCQl9qhLQ9+Twwdy0L1++gaf0IbhoRze9O6ab59iKHSaEvQa20zPHQR2t5fWkqdWrX4vYzenLDKV1pVE+/uiJHQv9yJGjtyi1k4oxYlm/J4vKhnbhlVA/aNWvgd1kiNZpCX4LSrNg0nvh0PVn5xUwZ159xQzrq5CqRAFDoS1DZlVvIX+es4ZPV2xjQsRmvXTuUvh2a+V2WSMhQ6EtQKCktY3ZcOlPmJ5JbUMJtZ/Tg5pE9qK2rYIoElEJffFVcWsaM71KZ9tUmtucUMqBTcx65sK+O7kWqiEJffFFW5njj+1ReWJTEjh8L6dehGX8d3Ydz+x2jsXuRKqTQl2qXlV/ErTPj+XLDTgZ2as6DF/Th7L4Ke5HqoNCXarV4/Q7ufHcV2fuKePCCPlx9QheFvUg1UuhLtYhPy+KdH
9J4e9kWekY25rVrj9e4vYgPFPpSpXblFvLEp+uZFZtOndrG+OM78dfze9Owrn71RPygf3lSJTKy9jF9SQrTl6RQWua4anhn7jjzWJo31A1NRPyk0JeA2pNXxIzvUnhhURIlZY7zB7TnxtO7c1y7pn6XJiIo9CWA1mRk8/vX48jI2sdJ0a34y3m9FfYiQUahL0fNOcc/v07myfkbaFI/gpkTh+vuVSJBSqEvR2V3biF/mr2SLxJ3MrJXWx6/uD9tmtTzuywROQiFvhyRPXlFvPbtZv69JIX8olL+Oro3154UpTn3IkFOoS+HxTnHm99vYcr8RLL3FXNazzbcc04vjd2L1BAKfam0HT8W8Of31/D5uu0M69qS+0b31glWIjWMQl8qZcWWvdz2TjyZWfu455xe/O6UbrrssUgNpNCXX5VXWMKTnyUy47tUmjeowxvXD2OYZuaI1FgKfTmoT1Zv5cn5iSTvyuO8/u3425i+tGikM2pFajKFvvyXfUWlPLtwI1O/3ES3No147drjOf3Ytn6XJSIBoNCXn5SWOd5etoWnF2xgd14R44Z05JGxfakXUdvv0kQkQBT6AsDOHwu5571VLFy/g/4dm/H0ZQM5tWcbv8sSkQBT6Ic55xz/WbON+z5cw4+FJUw+pxcTT+2mk6xEQpRCP4wt27yHf36dzIKE7fSMbMybvxtGr2N0kpVIKFPoh6GC4lKeX7SRl77YRKO6Edwyqgc3jYimbkQtv0sTkSqm0A8zazOzuXP2KhK25nDBgPY8elE/GtfTr4FIuNC/9jDyn9VbuWP2SurVqc1LVw7mnH7t/C5JRKqZQj8MZOcXc++Hq/l41Vb6dmjKqxOOJ7Jpfb/LEhEfKPRD2PacAj5ckcEb36eyNauAW0f1YJLG7kXCmkI/BO3IKeD5RUm8vWwLJWWOPu2bMuWGAbqblYgo9EPN69+lMGV+IvlFpVwS04kbTulK9zaN/S5LRIKEQj9E5BeVMPn91cyJz2RoVEseGduXHpFN/C5LRILMIQd3zexfZrbDzNZUaGtpZgvMbKP3tYXXbmb2nJklmdkqMxtc4XsmeNtvNLMJVdOd8FNUUsab36dyzrNfMyc+k1tGRvP2xOEKfBH5RZU50n8NeAGYUaHtHmChc+4xM7vHe343cA7Qw3sMA14ChplZS+B+IAZwQJyZzXXO7Q1UR8JNflEJT3yayKzYNPKLSunRtjHTrxvKabpejoj8ikOGvnPuKzOLOqB5DHC6tzwd+ILy0B8DzHDOOWCpmTU3s3betgucc3sAzGwBcDbw9lH3IMyUlTk+S9jG84uSWJtZfoLVBQPaM+q4trpejogc0pGO6Uc657Z6y9uASG+5A5BWYbt0r+1g7f/FzCYCEwE6d+58hOWFpjUZ2fxtXgJLk/fQsUUDXrhiEKP7t/e7LBGpQY76g1znnDMzF4hivNebBkwDiImJCdjr1mSFJaVMX5LCk59tAOC+0b357fAumm8vIoftSEN/u5m1c85t9YZvdnjtGUCnCtt19Noy+P/hoP3tXxzhe4eVlWlZ3D4rnuSdeYzq1ZYnxvWnVeN6fpclIjXUkR4qzgX2z8CZAMyp0H61N4tnOJDtDQPNB840sxbeTJ8zvTY5iLQ9+dz+TjxjXvyW/MJS/nl1DK9MiFHgi8hROeSRvpm9TflRemszS6d8Fs5jwCwzux5IBS71Nv8EOBdIAvKBawGcc3vM7GHgB2+7h/Z/qCs/55xjxnepTJmfSElZGb8/tRs3joimWYM6fpcmIiHAyifaBKeYmBgXGxvrdxnVJju/mDvfXclnCds5pUdrHh7Tl6jWjfwuS0RqGDOLc87F/NI6nZEbBHblFvLhigxe+mITu/OKuP2Mntw8MppatTQFU0QCS6HvozUZ2Tzz+QYWrd9BmYOhUS3554ReDO7cwu/SRCREKfR98HnCdmYsTeWbjTtpVDeCiad254IB7endXvenFZGqpdCvRvlFJdw5exXzVm+lXbP6TDy1OxNP7UbLRnX9Lk1EwoRCv5p8tWEnf5wVz568Im4d1YMbR3SnXkRtv8sSkTCj0K9i67fl8PDHCXybtJvoto35x5VDGNq1pd9liUiYUuhXkfyiEh76KIGZP6TRtH4Ed5zZk+tP7kaDujq6FxH/KPQDrKS0jHfj0nlqwQZ25RZyzYlR3DQymtY6k1ZEgoBCP4CWb9nLXz5YQ8LWHGK6tODFKwZrKEdEgopCPwD25BVx93urWJCwnZaN6vL3SwcwdlAHXd9eRIKOQv8orcnI5paZK8jYu49JI7rz+9O607S+rpMjIsFJoX+ECopLeW7hRl7+Kpmm9SN4/fphGsoRkaCn0D9MxaVlfLwqk38s3sTGHbmc178dD5zfhzZN9EGtiAQ/hX4lFRSXMjsunX9/s5nkXXlENq3Hi1cM5rz+7fwuTUSk0hT6lbA2M5t73lvN6oxsots2ZupVgznjuEgiaut2hSJSsyj0D2FOfAZ/mrWSZg3q8NQlA7hosGbliEjNpdD/Fcu37OWO2SuJiWrB1KuG0LyhLowmIjWbxicOYm9eEZPeXM4xzeor8EUkZOhI/xcUlZRx45vL2Z1bxPs3nqjAF5GQodD/BQ9/nMB3ybt56pIB9O3QzO9yREQCRsM7B/hs7TZeX5rKDSd35eIhHf0uR0QkoBT6FaTsyuNPs1fSp31T7jz7WL/LEREJOIW+p6C4lElvLaeWGVOvGqK7WolISNKYvueBuWtZm5nDqxNi6NSyod/liIhUCR3pA+/FpTPzhzQmjejOqOMi/S5HRKTKhH3oJ+/M5b45axjerSW3n9HT73JERKpUWIf+vqJSJr21groRtXjmskG6lo6IhLywHdN3zjH5/VWs35bDv645nmOa1fe7JBGRKhe2h7b/+jaFD+Mz+eMZPRlxbFu/yxERqRZhGfqr0rP430/W8ZvekUwaEe13OSIi1SbsQj97XzG3zoynTZN6TBnXn1q1dJlkEQkfYTWmX1Bcyu9mxJK+N58Z1w3ThdREJOyEVej/5cM1LNu8h2fHD+SE7q38LkdEpNqFzfDOByvSeTcunVtGRjNmYAe/yxER8UVYhP7azGzu/WANQ6NacsuoHn6XIyLim5AP/d25hUycEUezBnV47nKdgCUi4S2kx/SLS8uY9NZyduYW8t4fTtQJWCIS9qr9sNfMzjazRDNLMrN7qvK9Hpm3jqXJe3json7066g7YImIVGvom1lt4EXgHKA3cLmZ9a6K95oTn8FrS1K49qQoLhqsO2CJiED1H+kPBZKcc8nOuSJgJjAm0G+SsiuPez9YQ0yXFtx77nGBfnkRkRqrukO/A5BW4Xm61/YTM5toZrFmFrtz584jepNaZgzq3Jxn9cGtiMjPBF0iOuemOedinHMxbdq0OaLX6NyqIa9fP4wOzRsEuDoRkZqtukM/A+hU4XlHr01ERKpBdYf+D0APM+tqZnWB8cDcaq5BRCRsVes8fedciZndBMwHagP/cs6trc4aRETCWbWfnOWc+wT4pLrfV0REgvCDXBERqToKfRGRMKLQFxEJIwp9EZEw
Ys45v2s4KDPbCaQexUu0BnYFqJyaINz6C+pzuFCfD08X59wvnt0a1KF/tMws1jkX43cd1SXc+gvqc7hQnwNHwzsiImFEoS8iEkZCPfSn+V1ANQu3/oL6HC7U5wAJ6TF9ERH5uVA/0hcRkQoU+iIiYSQkQ786b75e1cysk5ktNrMEM1trZrd67S3NbIGZbfS+tvDazcye8/q+yswGV3itCd72G81sgl99qgwzq21mK8zsY+95VzP73uvXO96luTGzet7zJG99VIXXmOy1J5rZWf70pHLMrLmZvWtm681snZmdEAb7+Hbvd3qNmb1tZvVDbT+b2b/MbIeZranQFrD9amZDzGy19z3PmZkdsijnXEg9KL9k8yagG1AXWAn09ruuo+hPO2Cwt9wE2ED5TeWfAO7x2u8BHveWzwX+AxgwHPjea28JJHtfW3jLLfzu36/0+4/AWxc3eeMAAANbSURBVMDH3vNZwHhveSrwP97yjcBUb3k88I633Nvb9/WArt7vRG2/+/Ur/Z0O3OAt1wWah/I+pvw2qZuBBhX27zWhtp+BU4HBwJoKbQHbr8Ayb1vzvvecQ9bk9w+lCn7IJwDzKzyfDEz2u64A9m8O8BsgEWjntbUDEr3ll4HLK2yf6K2/HHi5QvvPtgumB+V3VFsIjAQ+9n6hdwERB+5jyu/NcIK3HOFtZwfu94rbBdsDaOYFoB3QHsr7eP/9slt6++1j4KxQ3M9A1AGhH5D96q1bX6H9Z9sd7BGKwzuHvPl6TeX9l3YQ8D0Q6Zzb6q3aBkR6ywfrf036uTwD3AWUec9bAVnOuRLvecXaf+qXtz7b274m9bcrsBP4tzek9YqZNSKE97FzLgN4EtgCbKV8v8UR2vt5v0Dt1w7e8oHtvyoUQz8kmVlj4D3gNudcTsV1rvzPfEjMvTWz0cAO51yc37VUowjKhwBecs4NAvIo/2//T0JpHwN449hjKP+D1x5oBJzta1E+8GO/hmLoh9zN182sDuWB/6Zz7n2vebuZtfPWtwN2eO0H639N+bmcBFxgZinATMqHeJ4FmpvZ/ju9Vaz9p35565sBu6k5/YXyI7R059z33vN3Kf8jEKr7GOAMYLNzbqdzrhh4n/J9H8r7eb9A7dcMb/nA9l8ViqEfUjdf9z6NfxVY55z7e4VVc4H9n+JPoHysf3/71d5MgOFAtvdfyfnAmWbWwjvKOtNrCyrOucnOuY7OuSjK990i59yVwGJgnLfZgf3d/3MY523vvPbx3qyPrkAPyj/0CjrOuW1Ampkd6zWNAhII0X3s2QIMN7OG3u/4/j6H7H6uICD71VuXY2bDvZ/h1RVe6+D8/pCjij44OZfyWS6bgHv9ruco+3Iy5f/9WwXEe49zKR/PXAhsBD4HWnrbG/Ci1/fVQEyF17oOSPIe1/rdt0r0/XT+f/ZON8r/MScBs4F6Xnt973mSt75bhe+/1/s5JFKJWQ0+93UgEOvt5w8pn6UR0vsYeBBYD6wBXqd8Bk5I7Wfgbco/syim/H901wdyvwIx3s9vE/ACB0wG+KWHLsMgIhJGQnF4R0REDkKhLyISRhT6IiJhRKEvIhJGFPoiImFEoS8iEkYU+iIiYeT/AH83qz4ZDMQ/AAAAAElFTkSuQmCC\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "tags": [], + "needs_background": "light" + } + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "MhILiPNCR75y", + "colab_type": "text" + }, + "source": [ + "## Implementing a new Multi-Armed Bandit by extending `MABAgent` base class" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "Tu0BBHt43N4m", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 744 + }, + "outputId": "9864790d-d1e4-4d9e-cfc0-330875e19ebf" + }, + "source": [ + "from genrl.agents import MABAgent\n", + "from genrl.agents import EpsGreedyMABAgent, BernoulliMAB\n", + "\n", + "class ReinforcementComparison(MABAgent):\n", + " def __init__(self, bandit, alpha, beta):\n", + " super(ReinforcementComparison, self).__init__(bandit)\n", + " self.alpha = alpha\n", + " self.beta = beta\n", + " self._pi = torch.zeros(self._bandit.arms)\n", + " self._r = torch.zeros(self._bandit.arms)\n", + "\n", + " def select_action(self, context):\n", + " p = torch.softmax(self._pi, 0)\n", + " a = torch.distributions.Categorical(p).sample()\n", + " return a.item()\n", + "\n", + " def update_params(self, context, action, reward):\n", + " self._pi[action] += self.beta * (reward - torch.mean(self._r))\n", + " self._r[action] = self._r[action] * (1 - self.alpha) + self._r[action] * self.alpha\n", + "\n", + "bandit = BernoulliMAB(arms=50, context_type=\"int\")\n", + "agent = ReinforcementComparison(bandit, 0.3, 0.3)\n", + "trainer = MABTrainer(agent, bandit)\n", + "results = trainer.train(2000)\n", + "\n", + "plt.plot(results[\"cumulative_regrets\"])" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "text": [ + "\n", + "Started at 25-08-20 18:26:48\n", + "Training ReinforcementComparison on BernoulliMAB for 2000 timesteps\n", + "timestep regret/regret reward/reward regret/cumulative_regret reward/cumulative_reward regret/regret_moving_avg reward/reward_moving_avg \n", + "100 0 1 44 56 0.44 0.56 \n", + "200 0 1 78 122 0.34 0.66 \n", + "300 0 1 111 189 0.33 0.67 \n", + "400 1 0 151 249 0.4 0.6 \n", + "500 0 1 184 316 0.33 0.67 \n", + "600 1 0 222 378 0.38 0.62 \n", + "700 1 0 264 436 0.42 0.58 \n", + "800 1 0 293 507 0.29 0.71 \n", + "900 0 1 330 570 0.37 0.63 \n", + "1000 0 1 367 633 0.37 0.63 \n", + "1100 1 0 405 695 0.38 0.62 \n", + "1200 0 1 440 760 0.35 0.65 \n", + "1300 1 0 471 829 0.31 0.69 \n", + "1400 1 0 509 891 0.38 0.62 \n", + "1500 0 1 541 959 0.32 0.68 \n", + "1600 1 0 571 1029 0.3 0.7 \n", + "1700 1 0 616 1084 0.45 0.55 \n", + "1800 0 1 655 1145 0.39 0.61 \n", + "1900 1 0 693 1207 0.38 0.62 \n", + "2000 1 0 727 1273 0.34 0.66 \n", + "Training completed in 0 seconds\n", + "Final Regret Moving Average: 0.34 | Final Reward Moving Average: 0.66\n" + ], + "name": "stdout" + }, + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "[]" + ] + }, + "metadata": { + "tags": [] + }, + "execution_count": 51 + }, + { + "output_type": "display_data", + "data": { + "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAXcAAAD4CAYAAAAXUaZHAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAgAElEQVR4nO3dd5xU9fX/8dfZRu8siLQFKYoUgY2CIBqJBQtgQ40RsGEMGo2JxiQmmmbURCNGfxACGLCBYghEsQWxU6QXaUtnBZbelmXLnN8fc/G7EpBdmJ2ZnX0/H495zJ3P3Lv3cGf2zd3P/dx7zd0REZHEkhTrAkREJPIU7iIiCUjhLiKSgBTuIiIJSOEuIpKAUmJdAED9+vU9IyMj1mWIiJQrc+fO3e7u6Ud7Ly7CPSMjgzlz5sS6DBGRcsXM1h/rPXXLiIgkIIW7iEgCUriLiCQghbuISAJSuIuIJCCFu4hIAlK4i4gkIIW7iEgMhELOH9/6kqycfWXy8+PiJCYRkYrA3Vm/IxcHnp+excS5m2h7Sk1aNagR8XUp3EVEoiC/MMQvJy1m4txNX7dd1uEUru7cuEzWp3AXESljby76igcnLiI3v4hruzahZ6v6VE1LpvcZDUlKsjJZp8JdRKSM5OzN46//XcmrszdSJTWZR69sx43nNKNSSnKZr/u44W5mbYEJxZpaAr8BxgXtGcA6YIC77zIzA4YBlwG5wGB3nxfZskVE4tOS7D0889+VfPnVXr7akwdAm4bVeeWObtSvXilqdRw33N19BXAWgJklA9nAJOAhYJq7P25mDwWvfw70AVoHj3OA4cGziEhCW5K9h77PfUrI4ZwWdbng9AZccuYp9Gpdn/B+b/SUtlumN7Da3debWT/ggqB9LPAh4XDvB4xzdwdmmlltM2vk7psjVLOISNzJzS9k8AuzSU4yJt/Vgw5NasW0ntKG+w3Aq8F0w2KBvQVoGEw3BjYWW2ZT0KZwF5GE4+78/eM1vDhjPdv35zP21rNjHuxQinA3szSgL/CLI99zdzczL82KzWwIMASgWbNmpVlURCQuFBSFuPPFuXywPIe05CQevLQt57c56o2Roq40e+59gHnuvjV4vfVwd4uZNQJygvZsoGmx5ZoEbd/g7iOBkQCZmZml+o9BRCQePDZ1GR8sz6Fvp1MZdsNZUe9X/zalufzAjfxflwzAFGBQMD0ImFysfaCFdQP2qL9dRBJJXkERD7y+kBc+W8flHRvFXbBDCffczawacBFwZ7Hmx4HXzOw2YD0wIGifSngYZBbhoZC3RKxaEZEYc3eGvjyPactz6N6yHo9d1SHugh1KGO7ufgCod0TbDsKjZ46c14GhEalORCRO5BUUMeKj1cxas5MZa3Zw49lN+dPVHWNd1jHpDFURkeMID3P8gtlrd1KzcgrXZzblD/07xLqsb6VwFxE5iv2HChn7+Tq+2n2Ql2dtAGDwuRn85op2ZXY9mEhSuIuIHMHduWPsHGas2UFqstG4dhV+dkkb+p/VOC77149G4S4iEjiYX8TYGesY+fEadh7I5+7vtuJnl7SNdVknROEuIgJ8nrWdIS/OZf+hQiqlJPHAJW350QWnxbqsE6ZwF5EKraAoxIMTFzFpfjZJBr/v355rujSmalr5jsfyXb2IyEn4dNV27pswn+378zn3tHo8cU1HmtatGuuyIkLhLiIV0nMfrOIv760E4GcXt+HuC1vHuKLIUriLSIXz8qz1/OW9lbRqUJ2Xbz+HhjUrx7qkiFO4i0iF8taizfxq0hJOS6/Gm/f0pHJq2d/yLhZKc+EwEZFy7Z0lWxj6yjzqVUvjlTu6JWywg/bcRaQCmLt+F/e/toD1O3KpXTWVN3/cMyG7YopTuItIwvpwRQ4/e30h2/fnA+HLBwz9bivSa0TvRtWxonAXkYSTV1DE428v55+fryM12RjYvTmDzs3gtPTqsS4tahTuIpJQ9uYVcO+r85m+YhvtG9dk+E1dE2bsemko3EUkIWzdm8ePX53PrLU7Abg+sylPXBu/11svawp3ESnX3J35G3dz86hZHMgvovfpDbjh7GZc1K5hrEuLKYW7iJRbm3blMmDEDL7akwfAiB904dL2jWJcVXxQuItIuePuvLNkC7+YtJjduQXc3rMFg87NqJB968dS0htk1wZGAe0BB24FVgATgAxgHTDA3XdZ+Er2wwjfJDsXGOzu8yJeuYhUWE++u4LhH64G4NkbO9O306kxrij+lPQM1WHAO+5+OtAJWAY8BExz99bAtOA1QB+gdfAYAgyPaMUiUqHNXruTER+tpkuz2ix69GIF+zEcN9zNrBbQCxgN4O757r4b6AeMDWYbC/QPpvsB4zxsJlDbzNQJJiIn7d2lWxjw9xnUrJzKmMHfoWbl1FiXFLdKsufeAtgGvGBm881slJlVAxq6++Zgni3A4UPTjYGNxZbfFLR9g5kNMbM5ZjZn27ZtJ/4vEJGEtye3gKfeW8GdL86lbrU0JtzZjdpV02JdVlwrSZ97CtAFuMfdZ5nZMP6vCwYAd3cz89Ks2N1HAiMBMjMzS7WsiFQc/56fzX0TFgDQsGYl3rjrXJrU0YHT4ylJuG8CNrn7rOD1RMLhvtXMGrn75qDbJSd4PxtoWmz5JkGbiEiJHcwvYtQna3jq/ZXUqZrKb/u15+J2DRP6So6RdNxwd/ctZrbRzNq6+wqgN/Bl8BgEPB48Tw4WmQLcbWbjgXOAPcW6b0REjuuLdTu5YeRMikJO/eqVePOenpxSK7Gv4hhpJR3nfg/wspmlAWuAWwj3179mZrcB64EBwbxTCQ+DzCI8FPKWiFYsIglt485cBo6ejQFPXtOR/p0bk5aiW0+UVonC3d0XAJlHeav3UeZ1YOhJ1iUiFdCabfu58R8zOVhQxL9+dC5dmtWJdUnlls5QFZG4sGLLPi5/9hMKQ86fr+2oYD9JCncRialQyJk4bxMPTlxEkqE99ghRuItIzKzfcYCBY2azfkcuAOOHdFewR4jCXUSiZuHG3Wzdm8fBgiI+z9rBhDnh8x3v7NWSn1zURsMcI0jhLiJR8f8+zOLJd1Z8o61Vg+o8dlUHzm5RN0ZVJS6Fu4iUqTfmbuKxqcvYcSCflvWr8fT1Z5GSZDSoUYkGNTV2vawo3EWkTLg7Yz9fx6P/+ZKqacnc27s1d/RqSfVKip1o0FYWkTJx+9g5TFueQ71qaXzw0wuoVVVXcIwmhbuIRExufiFjPl3LlIVfsXLrfvq0P4WnB5xFlTQdKI02hbuInLR12w/w68lL+GTVdgDSUpK48eymPHZVB8I3Z5NoU7iLyAlbu/0AL85Yz5jP1gJwQdt0LuvQiOu6NlGox5jCXURKrSgUPlj6uze/BKBBjUqMGpRJxya1Y1yZHKZwF5FS2ZdXwI9enscnq7ZTs3IKYwZ/h8wMjVOPNwp3ESmR/YcKmb9hFw+8vogte/O45MyGDLuhs84qjVMKdxE5roUbd/P9f8zkQH4RAL/v356buzWPcVXybRTuIvKtsnL28YNRszhYUMTDl5/BBW
3TadWgRqzLkuNQuIvI/wiFnGnLc/h89XbGzVhPUcgZP6Qb3VrWi3VpUkIKdxH5Wn5hiF9OWszEuZu+bkuvUYkRP+hK1+a6FG95UqJwN7N1wD6gCCh090wzqwtMADKAdcAAd99l4cGtwwjfRzUXGOzu8yJfuohE0qHCIm76xyzmrN9Fpya16N+5MVd1bkztqmmxLk1OQGn23L/r7tuLvX4ImObuj5vZQ8HrnwN9gNbB4xxgePAsInFqw45c+j7/KbtzCxjUvTm/7dc+1iXJSTqZW4r3A8YG02OB/sXax3nYTKC2mTU6ifWISBlauHE3vf48nd25BTx6ZTsFe4Ioabg78J6ZzTWzIUFbQ3ffHExvARoG042BjcWW3RS0fYOZDTGzOWY2Z9u2bSdQuoicrJVb93Hd32eQmmy8csc5DO7RItYlSYSUtFump7tnm1kD4H0zW178TXd3M/PSrNjdRwIjATIzM0u1rIicvFlrdnDrP78gvzDEm/f0pH3jWrEuSSKoROHu7tnBc46ZTQLOBraaWSN33xx0u+QEs2cDTYst3iRoE5E48fC/F/PSzA0AjPhBFwV7Ajput4yZVTOzGoengYuBJcAUYFAw2yBgcjA9BRhoYd2APcW6b0QkRg4VFrE7N5+n31/JSzM30KlpbeY+/D0uba9DYomoJHvuDYFJweU7U4BX3P0dM/sCeM3MbgPWAwOC+acSHgaZRXgo5C0Rr1pESmX+hl3cPHo2+w8VAtC9ZT1euv0ckpN0Wd5Eddxwd/c1QKejtO8Aeh+l3YGhEalORE5KKOS8+sUGHv73EpLN+OlFbahbPY2+nU5VsCc4naEqkoCWZO9h3oZdvLlwM7PX7SQtJYkJQ7rRuZnOMq0oFO4iCSQUcu5/bQH/XvDV1219O53Kn6/rSKUUXZq3IlG4i5Rzu3PzGfnxGpZt3sv8jbvZnVvAOS3q8vDl7cioX5UalVNjXaLEgMJdpBzbdSCfi5/5mG37DlGzcgot06vzvZ4NGPrdVrqHaQWncBcph9yd56dn8feP17Avr5CnB3Si/1mNSdJBUgko3EXKkZ0H8pk4dyMjPlrDzgP51KqSyuNXd+DqLk1iXZrEGYW7SDnxedZ2fvTKPHbnFlAlNZmffK8Nd1/YSkMa5agU7iJxLhRyfv/Wl7zw2ToAnh7Qics6NNKNqeVbKdxF4tjB/CJ+MmEB7yzdQofGtfj7zV05tXaVWJcl5YDCXSTO7Msr4NXZG1iwcTdTF28B4JIzGzL8pq46YColpnAXiRPuzrBpq3jmv6sAMIN2jWpy+3ktdMBUSk3hLhIH5m/YxX0TFrB+Ry7pNSrxiz6nc2WnU0lNPpmbpUlFpnAXiaHRn65lxEer2bbvEACDz83gkSvb6QQkOWkKd5EYmLVmB0++u4K563dRo3IKt/ZowR29WtColg6WSmQo3EWi6POs7Qz/aDWfrNoO6KJeUnYU7iJRUBRypi/P4fZxcwC48PQG/L5/exprWKOUEYW7SBnbujePa4Z/zqZdB6malsx/7unJaenVY12WJDiFu0gZWrRpN1f9v88pCjm39MjgrvNPo0HNyrEuSyqAEoe7mSUDc4Bsd7/CzFoA44F6wFzgZnfPN7NKwDigK7ADuN7d10W8cpE4t3FnLteNmEFRyHlh8Hf47ukNYl2SVCClGUR7L7Cs2OsngL+6eytgF3Bb0H4bsCto/2swn0iFsetAPn96exmXPvMxhwpDTB7aQ8EuUVeicDezJsDlwKjgtQEXAhODWcYC/YPpfsFrgvd7mwbtSgUx6pM1dPvTNP7+0RqqVkph+E1d6NS0dqzLkgqopN0yzwAPAjWC1/WA3e5eGLzeBDQOphsDGwHcvdDM9gTzby/+A81sCDAEoFmzZidav0jcGD97A394axlV05L5242dubLTqbEuSSqw4+65m9kVQI67z43kit19pLtnuntmenp6JH+0SNR9tfsgv3vzS848tSZLHr1EwS4xV5I99x5AXzO7DKgM1ASGAbXNLCXYe28CZAfzZwNNgU1mlgLUInxgVSTh5OYX8vs3l/Hq7A0APHltR125UeLCcffc3f0X7t7E3TOAG4AP3P0mYDpwbTDbIGByMD0leE3w/gfu7hGtWiQO5OzLo/dTH/Hq7A1k1KvKG3d158xTa8W6LBHg5Ma5/xwYb2Z/AOYDo4P20cCLZpYF7CT8H4JIQpm8IJv7X1tIUch5+PIzuK1nC13sS+JKqcLd3T8EPgym1wBnH2WePOC6CNQmEncKikK88NlaHpu6nLTkJMbc+h3Ob6NjRhJ/dIaqSAm4O7+ZvJQXZ64HoHbVVN67r5fONpW4pXAX+RbuzkuzNjDiw9Vk7z7Iea3r06d9I67LbKIbaUhcU7iLHEVRyJm6eDOPv72c7N0HSTIY0qslD1zSVqEu5YLCXeQI+/IKuG7EDJZv2QfA0O+exn3fa6NQl3JF4S4S2Lo3jz9NXcaUhV8R8vAt7+6/uA01K6fGujSRUlO4ixC+dMCvJy+hoMhp16gmd57fkn5nNT7+giJxSuEuFdqe3AKuGfE5WTn7ARg9KJPeZzSMcVUiJ0/hLhXWtGVb+fGr8zmQX8QtPTJ4+PJ2JOvSAZIgFO5S4Wzec5B7X13A7HU7AXj2xs701YW+JMEo3KVC2bznINcOn0H27oNc2elUfn35GToRSRKSwl0qhKKQM+qTNfzp7eUAPHFNB67/ju4jIIlL4S4J76vdB7luRHhvvUmdKjx+dUd6tq4f67JEypTCXRJWKOTMXLODeycsYNu+Q9zZqyX3fq81VdP0tZfEp2+5JKx7xs/nrUWbAXju+525oqMOmkrFoXCXhDR9eQ5vLdrMpWeewiN929GoVpVYlyQSVQp3STijPlnDH95aRuPaVXjmhrOonJoc65JEok7hLglh4cbdvPDZWj7N2sH2/YdIr1GJUYMyFexSYSncpVzL2ZvHjDU7+NnrCykoclqmV+Park24/6I2pKXoKo5ScR033M2sMvAxUCmYf6K7P2JmLYDxQD1gLnCzu+ebWSVgHNAV2AFc7+7ryqh+qaBCIefJd1cw4qPVAKQkGVN/fB7tTq0Z48pE4kNJ9twPARe6+34zSwU+NbO3gfuBv7r7eDMbAdwGDA+ed7l7KzO7AXgCuL6M6pcKyN0Z+so83l6yhZbp1Xjg4rZ0blaHU2rpTFORw477d6uH7Q9epgYPBy4EJgbtY4H+wXS/4DXB+71Nt4WXCAmFnEemLOXtJVu45MyGTLv/fPp0aKRgFzlCifrczSyZcNdLK+B5YDWw290Lg1k2AYcvft0Y2Ajg7oVmtodw1832CNYtFcz7X27lo5U5vDRzAwCdm9VmxA+6ov0GkaMrUbi7exFwlpnVBiYBp5/sis1sCDAEoFkzXeNDjq4o5Lw6ewMP/3sJAI1rV+H281owsHuGgl3kW5RqtIy77zaz6UB3oLaZpQR7702A7GC2bKApsMnMUoBahA+sHvmzRgIjATIzM/3E/wmSiCbO3cSLM9ezcss+DhYUcfopNXj59nOoWy1NoS5SAiUZLZMOFATBX
gW4iPBB0unAtYRHzAwCJgeLTAlezwje/8DdFd5yXFk5+/nPwq9YvmUv05blkJaSRI9W9WnXqAYDz82gXvVKsS5RpNwoyZ57I2Bs0O+eBLzm7m+a2ZfAeDP7AzAfGB3MPxp40cyygJ3ADWVQtySIvXkFPPvfVWzclcu7S7cC4WGN57WuzxPXdqRBDR0oFTkRxw13d18EdD5K+xrg7KO05wHXRaQ6SVihkDPms7U89d5KDhYUUa9aGh2b1OKP/TtweqMapCbrBCSRk6EzVCXqpi/P4ZZ/fgFA9Uop/PX6TlzVuUmMqxJJLAp3iar3lm5hyItzqVk5hQcuact1mU11/ReRMqBwl6goCjm//c9Sxs1YT+XUJCbedS5tGtaIdVkiCUvhLmVub14BQ1+exyerttOpaW2e/35nmtSpGuuyRBKawl3KVFbOPi5/9lMOFYa4pksT/nJdR41TF4kChbuUmaKQ85MJCykKOX+9vhP9OjVWsItEicJdysRLM9d/fcmAR65sp9EwIlGmcJeIe3baKp5+fyWpycafru7ItV0V7CLRpnCXiDlUWMTz01fz7LRVtGlYnf/c05NKKRrmKBILCneJiKVf7WHAiBkcyC8io15V/j20h4JdJIYU7nJS3J1J87P55aTF5BWE+NVlZ3Bz9+Y6MUkkxhTuclLeXLSZ+19bSHKSMX5IN7q1rBfrkkQEhbucoOzdB3lp5nqGf7iajHpVeee+XtpbF4kjCncplbyCIsZ8tpa/vLuCkEPdammMuLmrgl0kzijcpcS+WLeTX/xrMVk5+6lZOYXnvt+FHq3qk5ykE5NE4o3CXY6rsCjEM/9dxXPTswC4s1dLfn7p6SQp1EXilsJdvtW7S7fwwOsL2ZtXSJM6VXjjrnNpWFN3RxKJdwp3Oaa/vr+SYdNWkZxk/OR7bbjrgtNIS9EdkkTKA4W7/I+vdh/kl5MW8+GKbbRMr8akH/WgVpXUWJclIqVw3HA3s6bAOKAh4MBIdx9mZnWBCUAGsA4Y4O67LHzZv2HAZUAuMNjd55VN+RJJS7L38OS7K/h45TYAzm+Tzj8GZmpvXaQcKsmeeyHwU3efZ2Y1gLlm9j4wGJjm7o+b2UPAQ8DPgT5A6+BxDjA8eJY49dwHqxjz2Tp2HsgHoHvLejx4aVs6N6sT48pE5EQdN9zdfTOwOZjeZ2bLgMZAP+CCYLaxwIeEw70fMM7dHZhpZrXNrFHwcySO7D9UyENvLOLNRZtpVKsyg8/NYGD35rRMrx7r0kTkJJWqz93MMoDOwCygYbHA3kK42wbCwb+x2GKbgrZvhLuZDQGGADRr1qyUZcvJmrlmBwNHzya/KMTZLery8u3nkJqs7heRRFHi32Yzqw68Adzn7nuLvxfspXtpVuzuI909090z09PTS7OonKTFm/Zww8iZOM6fru7AhCHdFOwiCaZEe+5mlko42F92938FzVsPd7eYWSMgJ2jPBpoWW7xJ0CZxICtnHwPHzKJKajLv3teLZvV0o2qRRFSS0TIGjAaWufvTxd6aAgwCHg+eJxdrv9vMxhM+kLpH/e2xVVgUYtSna/l89Y6vR8KMGpipYBdJYCXZc+8B3AwsNrMFQdsvCYf6a2Z2G7AeGBC8N5XwMMgswkMhb4loxVIqBw4Vcv3IGSzJ3kvl1CS+k1GH3/dvz+mn1Ix1aSJShkoyWuZT4FgXEel9lPkdGHqSdUkELN+yl75/+4z8ohC39MjgN1e0I/yHmIgkOp2hmoCWbd7Le0u3MmzaSkIOT13XiWt0k2qRCkXhnkCKQs4Dry/kX/PDx6+rpiUzfkg3OjapHePKRCTaFO4JYMuePMbOWMeELzay80A+nZrU4vFrOpJRrxpV0nQTDZGKSOFezuXszaPPsI/ZlVtAtbRkftHndO48/7RYlyUiMaZwL6eKQs6stTu4efRsikLOyJu7clG7hjpgKiKAwr1c+vKrvdw0aia7cgswg+E3deHiM0+JdVkiEkcU7uWIu/OPT9bwxDsrKAo591zYipu7N6dBDd0ZSUS+SeFeTuzNK+Cul+byWdYOalZOYfgPutKjVf1YlyUicUrhXg7MWL2D74+aiTtc3rERz93YWX3rIvKtFO5x7IPlW3lkylI27jxIksHIgZlc1K7h8RcUkQpP4R6HCotCPDZ1OWM+WwvAzd2ac0uPDN1EQ0RKTOEeh+58cS7TlufQqkF1xg/pRv3qlWJdkoiUMwr3OPP24s1MW57DjWc35bGrOqhvXUROiMI9Try7dAtPv7eSFVv30bRuFR7te6aCXUROmMI9hlZs2cc7S7awcNNuPlieQ3KScVXnxtzbuzWVUnRNGBE5cQr3GNm6N4/+z3/GwYIiUpKMc0+rx99u7Ew99a+LSAQo3GMgFHKGvjyPQ4VFTLm7hy7JKyIRp3CPkqKQs2DjLj5auZ1/fraWvXmF3H9RGwW7iJSJktwgewxwBZDj7u2DtrrABCADWAcMcPddwc20hxG+h2ouMNjd55VN6fHP3Xl7yRYWZ+9h9CdryS8KAVApJYlfXXYGt/VsEeMKRSRRlWTP/Z/Ac8C4Ym0PAdPc/XEzeyh4/XOgD9A6eJwDDA+eK5TCohDDP1zN2Bnr2b7/EAB1q6Vxa48MerSqz1lNa2skjIiUqZLcIPtjM8s4orkfcEEwPRb4kHC49wPGBTfJnmlmtc2skbtvjlTB5cFrczbx1PsraVKnCkN6teTOXi2pUzWNpCQFuohEx4n2uTcsFthbgMMXPGkMbCw236ag7X/C3cyGAEMAmjVrdoJlxJ99eQU8/f4Kzs6oy4Q7u2kPXURiIulkf0Cwl+4nsNxId89098z09PSTLSNu/OHNZWzfn8+vLj9DwS4iMXOie+5bD3e3mFkjICdozwaaFpuvSdCW0DbtymX87I18vGobizbt4bquTejUVKNgRCR2TjTcpwCDgMeD58nF2u82s/GED6TuSeT+9lDIefaDVTzz31UA1KiUwtVdGvO7fu1jXJmIVHQlGQr5KuGDp/XNbBPwCOFQf83MbgPWAwOC2acSHgaZRXgo5C1lUHNcyMrZzx3j5rB2+wEa1KjEUwM6cV7rxOleEpHyrSSjZW48xlu9jzKvA0NPtqh4t3rbfq7826ccLCjih+efxkN9To91SSIi36AzVEvpwxU5DH7hCwAm/rA7mRl1Y1yRiMj/UriX0OJNexjy4hw278kjOcmYMKSbgl1E4pbC/Tg+z9rO3z7IYvmWvezKLWBg9+bccV5LmtatGuvSRESOSeF+FO7O0++vZPiHqykMhYfwn3taPX56cVu6Nq8T4+pERI5P4X6Ef8/P5teTl7Avr5BTa1Xmpm7NuemcZtSumhbr0kRESkzhTnhP/T+LNvP6nI18smo7AA9c0pYhvVqSmnzSJ/GKiERdhQ/3w8Mac/OLADi/TTrP3tiZWlVSY1yZiMiJq9Dh/nnWdm4aPQt3uP+iNvzw/NNIS9GeuoiUfxU23Odv2MXAMbNJNmP8nRrWKCKJpcKFe35hiJ9MWMBbizeTZDDl7p60b1wr1mWJiERUhQr3XQfyuXXsF8zf
sJvuLevx+DUdaF6vWqzLEhGJuAoT7tNX5PDolKWs35HLHee14JeX6XrrIpK4EjrcD+YXMWzaKl6euZ59hwoB+PUV7XRjahFJeAkb7oVFIfo9/ykrt+4nvUYlburWnB+e31InI4lIhZCw4f7gG4tYuXU/91/UhnsubKUuGBGpUBIu3PfmFfC7/3zJv+ZlMyCzCT/u3TrWJYmIRF1Chfu+vAIuf/YTNu48SGbzOvy2r253JyIVU8KE+z8+XsMfpy4DdNBURKRMzrU3s0vNbIWZZZnZQ2WxjuKen57FH6cuo3m9qgy/qYuCXUQqvIjvuZtZMvA8cBGwCfjCzKa4+5eRXhfArDU7+PO7K2jXqCYT7+pO1bSE+WNEROSElcWe+9lAlruvcfd8YDzQrwzWw2tfbOT6kTOpWTmFV+44R8EuIhIoi3BvDGws9npT0PYNZjbEzOaY2a5JFeEAAAXJSURBVJxt27ad0IpqV03lqs6Nee2H3TV+XUSkmJjt6rr7SGAkQGZmpp/Iz7j4zFO4+MxTIlqXiEgiKIs992ygabHXTYI2ERGJkrII9y+A1mbWwszSgBuAKWWwHhEROYaId8u4e6GZ3Q28CyQDY9x9aaTXIyIix1Ymfe7uPhWYWhY/W0REjk83DBURSUAKdxGRBKRwFxFJQAp3EZEEZO4ndP5QZIsw2wasP8HF6wPbI1hOpKiu0onXuiB+a1NdpZOIdTV39/SjvREX4X4yzGyOu2fGuo4jqa7Side6IH5rU12lU9HqUreMiEgCUriLiCSgRAj3kbEu4BhUV+nEa10Qv7WprtKpUHWV+z53ERH5X4mw5y4iIkdQuIuIJKByHe7RvhH3EetuambTzexLM1tqZvcG7Y+aWbaZLQgelxVb5hdBrSvM7JIyrG2dmS0O1j8naKtrZu+b2arguU7Qbmb2bFDXIjPrUkY1tS22TRaY2V4zuy8W28vMxphZjpktKdZW6u1jZoOC+VeZ2aAyquvPZrY8WPckM6sdtGeY2cFi221EsWW6Bp9/VlC7lUFdpf7cIv37eoy6JhSraZ2ZLQjao7m9jpUN0f2OuXu5fBC+nPBqoCWQBiwE2kVx/Y2ALsF0DWAl0A54FPjZUeZvF9RYCWgR1J5cRrWtA+of0fYk8FAw/RDwRDB9GfA2YEA3YFaUPrstQPNYbC+gF9AFWHKi2weoC6wJnusE03XKoK6LgZRg+olidWUUn++InzM7qNWC2vuUQV2l+tzK4vf1aHUd8f5TwG9isL2OlQ1R/Y6V5z33qN2I+2jcfbO7zwum9wHLOMq9YovpB4x390PuvhbIIvxviJZ+wNhgeizQv1j7OA+bCdQ2s0ZlXEtvYLW7f9tZyWW2vdz9Y2DnUdZXmu1zCfC+u+90913A+8Clka7L3d9z98Lg5UzCdzY7pqC2mu4+08MJMa7YvyVidX2LY31uEf99/ba6gr3vAcCr3/Yzymh7HSsbovodK8/hXqIbcUeDmWUAnYFZQdPdwZ9XYw7/6UV063XgPTOba2ZDgraG7r45mN4CNIxBXYfdwDd/6WK9vaD02ycW2+1Wwnt4h7Uws/lm9pGZnRe0NQ5qiUZdpfncor29zgO2uvuqYm1R315HZENUv2PlOdzjgplVB94A7nP3vcBw4DTgLGAz4T8No62nu3cB+gBDzaxX8TeDPZSYjIG18K0X+wKvB03xsL2+IZbb51jM7FdAIfBy0LQZaObunYH7gVfMrGYUS4q7z+0IN/LNHYiob6+jZMPXovEdK8/hHvMbcZtZKuEP72V3/xeAu2919yJ3DwH/4P+6EqJWr7tnB885wKSghq2Hu1uC55xo1xXoA8xz961BjTHfXoHSbp+o1Wdmg4ErgJuCUCDo9tgRTM8l3J/dJqiheNdNmdR1Ap9bNLdXCnA1MKFYvVHdXkfLBqL8HSvP4R7TG3EHfXqjgWXu/nSx9uL91VcBh4/kTwFuMLNKZtYCaE34QE6k66pmZjUOTxM+ILckWP/ho+2DgMnF6hoYHLHvBuwp9qdjWfjGHlWst1cxpd0+7wIXm1mdoEvi4qAtoszsUuBBoK+75xZrTzez5GC6JeHtsyaoba+ZdQu+owOL/VsiWVdpP7do/r5+D1ju7l93t0Rzex0rG4j2d+xkjgrH+kH4KPNKwv8L/yrK6+5J+M+qRcCC4HEZ8CKwOGifAjQqtsyvglpXcJJH5L+lrpaERyIsBJYe3i5APWAasAr4L1A3aDfg+aCuxUBmGW6zasAOoFaxtqhvL8L/uWwGCgj3Y952ItuHcB94VvC4pYzqyiLc73r4OzYimPea4PNdAMwDriz2czIJh+1q4DmCM9EjXFepP7dI/74era6g/Z/AD4+YN5rb61jZENXvmC4/ICKSgMpzt4yIiByDwl1EJAEp3EVEEpDCXUQkASncRUQSkMJdRCQBKdxFRBLQ/wft6WGUWOHgsAAAAABJRU5ErkJggg==\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "tags": [], + "needs_background": "light" + } + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "sL6RsMtSSFRS", + "colab_type": "text" + }, + "source": [ + "## Implementing a new agent by extending the `DCBAgent` base class" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "r4ld3UXj31ne", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 623 + }, + "outputId": "4caa035b-7b89-48ac-f689-4ca8c6d7cf3e" + }, + "source": [ + "from genrl.agents import DCBAgent, BernoulliMAB\n", + "from genrl.trainers import DCBTrainer\n", + "from genrl.agents.bandits.contextual.common import NeuralBanditModel, TransitionDB\n", + "\n", + "class NeuralAgent(DCBAgent):\n", + " def __init__(self, bandit, **kwargs):\n", + " super(NeuralAgent, self).__init__(bandit, **kwargs)\n", + " self.model = (\n", + " NeuralBanditModel(\n", + " context_dim=self.context_dim,\n", + " n_actions=self.n_actions,\n", + " hidden_dims=kwargs.get(\"hidden_dims\", [64]),\n", + " **kwargs\n", + " )\n", + " .to(torch.float)\n", + " .to(self.device)\n", + " )\n", + " self.db = TransitionDB(self.device)\n", + " self.t = 0\n", + "\n", + " def select_action(self, context):\n", + " self.t += 1\n", + " if self.t < self.n_actions * self.init_pulls:\n", + " return torch.tensor(\n", + " self.t % self.n_actions, device=self.device, dtype=torch.int\n", + " ).view(1)\n", + "\n", + " results = self.model(context)\n", + " action = torch.argmax(results[\"pred_rewards\"]).to(torch.int).view(1)\n", + " return action\n", + "\n", + " def update_db(self, context, action, reward):\n", + " self.db.add(context, action, reward)\n", + "\n", + " def update_param(self, action, batch_size=512, train_epochs=20):\n", + " self.model.train_model(self.db, train_epochs, batch_size)\n", + "\n", + "bandit = BernoulliMAB(arms=10, context_type=\"tensor\")\n", + "agent = NeuralAgent(bandit)\n", + "trainer = DCBTrainer(agent, bandit)\n", + "results = trainer.train(5000)\n", + "\n", + "plt.plot(results[\"cumulative_regrets\"])" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "text": [ + "\n", + "Started at 25-08-20 18:28:18\n", + "Training NeuralAgent on BernoulliMAB for 5000 timesteps\n", + "timestep regret/regret reward/reward regret/cumulative_regret reward/cumulative_reward regret/regret_moving_avg reward/reward_moving_avg \n", + "100 0 1 14 86 0.14 0.86 \n", + "200 0 1 15 185 0.075 0.925 \n", + "300 0 1 18 282 0.016 0.984 \n", + "400 0 1 20 380 0.02 0.98 \n", + "500 0 1 23 477 0.028 0.972 \n", + "\n", + "Encounterred exception during training!\n", + "'NeuralAgent' object has no attribute 'update_params'\n", + "\n", + "Training completed in 0 seconds\n", + "Final Regret Moving Average: 0.028 | Final Reward Moving Average: 0.972\n" + ], + "name": "stdout" + }, + { + "output_type": "stream", + "text": [ + "Traceback (most recent call last):\n", + " File \"/content/genrl/genrl/trainers/bandit.py\", line 194, in train\n", + " self.agent.update_params(\n", + "AttributeError: 'NeuralAgent' object has no attribute 'update_params'\n" + ], + "name": "stderr" + }, + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "[]" + ] + }, + "metadata": { + "tags": [] + }, + "execution_count": 52 + }, + { + "output_type": "display_data", + "data": { + "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAXAAAAD7CAYAAABzGc+QAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAVHUlEQVR4nO3de3Bc5XnH8d9jyRf5giRHslF8QRhcAgQwRHWd4DTO3XE7DcykaQgDnsQzTidJS2aYdEjStOm0nUmbW9MppXUKA51S0mYSCiU0hDikSdqMiUxMLGMcDDEFR5YvaNeXXUkr6ekfe2QLI6PLnt2z5z3fz4xGu2ePdp93svzy+j3v+x5zdwEA0mdW0gUAAGaGAAeAlCLAASClCHAASCkCHABSigAHgJSaNMDNbIWZPWZmT5nZHjO7JTr+OTM7aGa7op9N1S8XADDGJpsHbmYdkjrc/QkzWyRpp6TrJL1f0kl3/2L1ywQAnK1xshPcvVdSb/T4hJntlbRsJh/W1tbmnZ2dM/lTAMisnTt3HnX39rOPTxrg45lZp6SrJe2QdK2kj5vZzZK6Jd3q7v2v9vednZ3q7u6ezkcCQOaZ2fMTHZ/yRUwzWyjpm5I+4e7HJd0h6SJJa1TuoX/pHH+31cy6zaz7yJEj0y4cADCxKQW4mc1WObzvdfdvSZK797n7iLuPSvqapLUT/a27b3P3Lnfvam9/xb8AAAAzNJVZKCbpTkl73f3L4453jDvtekk98ZcHADiXqYyBXyvpJkm7zWxXdOzTkm4wszWSXNIBSR+pSoUAgAlNZRbKjyXZBC89HH85AICpYiUmAKQUAQ4AKTWteeAAgFd3cnBY9/zvAQ2WRl52/PprluvCtgWxfhYBDgAx+sG+w/rCI/skSTbu6uE1F7QS4ABQz146NSRJ6v7jd6ht4dyqfhZj4AAQo1yhJElqbppd9c8iwAEgRrlCSQvnNmp2Q/XjlQAHgBjlCkM16X1LBDgAxCpXLKllPgEOAKmTKwzVLMCZhQIg0w7minruyMnY3q/v+KDWrGyJ7f1eDQEOINO23P1TPX3oRKzvuemK82N9v3MhwAFk2qHjA3rP68/XlvUXxvJ+ZtLlr22O5b0mQ4ADyKzRUVe+WNLqJQvV1bk46XKmjYuYADLrxMCw3KXm+XOSLmVGCHAAmdVfKC97b6nRvO24EeAAMitXLC97r9W0v7gR4AAyKzfWA0/pEAoXMQFkysioqzQyKkk6cmJQUnp74AQ4gEzZ9NUfaV/fy+d9L6YHDgD1rTQyqn19J/Tm1W1600VtkqTXtsxT6wICHADqWj66aPmOS5dq85s6ky0mBlzEBJAZYzdbSOuY99kIcACZkS+me9bJ2QhwAJlxugee0oU7ZyPAAWRGP0MoAJBOaV+4czZmoQCoe//8kwP6RveLFb9P3/EBzTJp0dwwoi+MVgAI2oO7fqWDuaLWrKjsTjfti+bq0o5FmjXLYqosWQQ4gLqXK5a0btVi/f2Nb0i6lLrCGDiAupcrlNTcFMa4dZwIcAB1zd2VL9buTu9pMmmAm9kKM3vMzJ4ysz1mdkt0fLGZPWpmz0S/W6tfLoCsKQyNqDTiwczdjtNUeuDDkm5198skrZP0MTO7TNJtkra7+2pJ26PnABCrsZsutAYy9S9Okwa4u/e6+xPR4xOS9kpaJum9ku6JTrtH0nXVKhJAdvWfKs/dbmYI5RWmNQvFzDolXS1ph6Sl7t4bvXRI0tJYKwNQ1x7/5Uv6r57eyU+sUN/xAUnhLH+P05QD3MwWSvqmpE+4+3GzM/Mo3d3NzM/xd1slbZWklStXVlYtgLrxd4/t1//sP6r5cxqq/lnLWpq0qn1h1T8nbaYU4GY2W+XwvtfdvxUd7jOzDnfvNbMOSYcn+lt33yZpmyR1dXVNGPIA0idXGNKbV7fp7g+tTbqUzJrKLBSTdKekve7+5XEvPShpc/R4s6QH4i8PQL3qLwwxrJGwqfTAr5V0k6TdZrYrOvZpSZ+X9O9mtkXS85LeX50SAdSjXKEUzKZQaTVpgLv7jyWda+OAt8dbDoA0GB4Z1YmBYRbXJIyVmACm7fjAsCRmhiSNAAcwbaHtq51W7EYI1Dl314+eOapTg8NJl3LagWMFSSyuSRoBDtS5J1/M6+a7Hk+6jAmtaG1KuoRMI8CBOnc4Wol4+wev0UVLFiRczRkL5zZqeev8pMvINAIcqHNjmzldubxZKxYTmDiDi5hAnctHd1JvXcAFQ7wcAQ7Uuf7CkBpnmRbUYM8RpAsBDtS5XLGklvmzNX4DOUAiwIG6ly+U1MyCGUyAi5hADEZHXceiGw/E7ciJQRbMYEIEOBCDzz7Qo3t3/F/V3v9dl3G/FLwSAQ7EYP/hk1rVtkAfWn9hVd5//cVtVXlfpBsBDsQgXyzp4iULddO6C5IuBRnCRUwgBv2FIbZWRc0R4EAMcoWSWrnQiBojwIEKDZRGNDg8ys58qDkCHKhQLlrq3tJEDxy1RYADFeo/fXMDeuCoLWahINM++x89+vH+oxW9x0BpRBK3F0PtEeDItId39+q8ptm6YllzRe+zYG6j1qxsiakqYGoIcGSWuytXLOkDa1fok+9+XdLlANPGGDgy6+TgsEZGnYuPSC0CHJk1NnuE6X9IKwIcmXVm+h8BjnQiwJFZuWJ5+h+3KkNaEeDILHrgSDtmoaDufe+pPj2y51Ds7/vc0VOSGANHehHgqHv/8N/P6ucH82qrwlDH2s7FWswmVEgpAhx1r78wpHdeulS333hN0qUAdYUxcNS9fLHEMAcwAQIcdc3dlSuUuNAITIAAR107OTis4VFnpz9gApMGuJndZWaHzaxn3LHPmdlBM9sV/WyqbpnIKvbaBs5tKj3wuyVtnOD4V9x9TfTzcLxlAWX5IsvdgXOZdBaKu//QzDqrXwpq4ckXctrXdyLpMqbs2cMnJYn7TQITqGQa4cfN7GZJ3ZJudff+iU4ys62StkrSypUrK/g4xOH3/2WnevMDSZcxLQ2zTMtbm5IuA6g7Mw3wOyT9uSSPfn9J0ocnOtHdt0naJkldXV0+w89DDNxdR08O6qZ1F+gjb1mVdDlTtnBuo1rogQOvMKMAd/e+scdm9jVJD8VWEarm1NCISiOu5a1NWt46P+lyAFRoRtMIzaxj3NPrJfWc61zUj1x0813Gk4EwTNoDN7P7JG2Q1GZmL0r6U0kbzGyNykMoByR9pIo1IibcwAAIy1RmodwwweE7q1ALqoztU4GwsBIzQ8ZuYMAFQSAMmdiNcHB4RL/KpWvqXDU8d6S8/zXL0oEwZCLA//C+n+mRPX2Tn5gBjbNMzQyhAEHIRIA/f6ygK5c368PXXph0KYlb1tqkebMbki4DQAwyEeD5YknrL27TdVcvS7oUAIhNJi5i5golxn0BBCf4AB8ojahYGmHmBYDgBB/gY9uR0gMHEJrgA5wbAgAIVQYCfGzxCj1wAGEJP8DH7ujC3GcAgQk/wMd24FvAEAqAsGQgwNnACUCYwg/wYkmzG0zz57D6EEBYwg/wQknNTX
NkZkmXAgCxCj7A88UhZqAACFLwAd5/qqRWAhxAgIIP8FyxPIQCAKEJOsB7Dua1t/c4QygAghR0gP/nk7+SJL3tdUsSrgQA4hd0gPcXhrT0vLnadEVH0qUAQOyCDvBcocQmVgCCFXaAF7mRA4BwhR3gBeaAAwhX4AHOEAqAcAUb4O7OEAqAoAV5V/qh4VF9Y+cLGhoe5V6YAIIVZA98xy+P6TP390iSLl6yMOFqAKA6guyBHztZvonDQ3+wXq9f1pxwNQBQHUH2wMfuwvPalqaEKwGA6gkywPuju/CcNy/If2AAgKRAAzxfLGnRvEY1NgTZPACQNIUAN7O7zOywmfWMO7bYzB41s2ei363VLXN6WMADIAum0kW9W9LGs47dJmm7u6+WtD16XjdyRRbwAAjfpAHu7j+U9NJZh98r6Z7o8T2Srou5rorkCizgARC+mQ4SL3X33ujxIUlLY6onFuUhFHrgAMJW8VU+d3dJfq7XzWyrmXWbWfeRI0cq/bgpKQ+h0AMHELaZBnifmXVIUvT78LlOdPdt7t7l7l3t7e0z/LipGx115dkDBUAGzDTAH5S0OXq8WdID8ZRTuRMDw3IXQygAgjeVaYT3SfqJpEvM7EUz2yLp85LeaWbPSHpH9Lwu5IrlVZgMoQAI3aRLFd39hnO89PaYa4lFLlqFyRAKgNAFt1SxP9oHhQAHELrgAjxfHOuBMwYOIGzBBfjpIRTGwAEELtgAbybAAQQuvAAvDmnRXHYiBBC+4FIuVyipZQG9bwDhCzDAh9iJEEAmhBfgLKMHkBHBBXi+UOICJoBMCC7Ac8WSWpkDDiADggnwU4PD+ti9T6if26kByIhgAvyp3uP69u5eXbJ0kTZcUv1tawEgaZNuZpUW/afKe6B88Xev0uuXNSdcDQBUXzA98FyRFZgAsiWYAM+zjSyAjAkmwPsLQ2qYZVo4N5hRIQB4VcEE+NiNjM0s6VIAoCaCCfB8oaRmhk8AZEgQAZ4vlvTt3b0s4AGQKUEE+Hd6eiVJFyyen3AlAFA7QQT4S6fKM1D+8vorEq4EAGoniADPFYc0p3GW5s0OojkAMCVBJF7uFDNQAGRPGAFeHOICJoDMCSPAmUIIIIOCCPB8tIgHALIk9QE+UBrR04dOsAcKgMxJfYDf/th+SdL5zU0JVwIAtZX6AO/ND0iSPrrhooQrAYDaSn2A5wolXdpxnubNbki6FACoqdQHeL44xAVMAJmU+gDPFUpcwASQSRXd/cDMDkg6IWlE0rC7d8VR1HTkiiW1sIgHQAbFcfuat7r70RjeZ9rcXbnCED1wAJmU6iGUk4PDKo04Y+AAMqnSAHdJ3zWznWa2NY6CpqrnYF5v+IvvSZJaFzCEAiB7Kh1CWe/uB81siaRHzexpd//h+BOiYN8qSStXrqzw4874Rd8JDQ2P6qMbLtK7Lz8/tvcFgLSoqAfu7gej34cl3S9p7QTnbHP3Lnfvam9vr+TjXiZXKN/EYetvrlIzQygAMmjGAW5mC8xs0dhjSe+S1BNXYZPJFUsykxbNI7wBZFMlQyhLJd0f3UShUdK/uvt3YqlqCvKFIZ03b7YaZnETBwDZNOMAd/fnJF0VYy3T0l8oqZXpgwAyLLXTCHPFkppZwAMgw1Ib4PkCe6AAyLbUBnh5CT0BDiC70hvghRI3MgaQaakM8JFR1/GBEvO/AWRaKgP8eLEkdzGEAiDTUhnguWJ5FSYBDiDL0hnghSFJYh9wAJmW0gCPeuCMgQPIsHQGeJEeOACkM8DpgQNAegPcTDqPAAeQYSkNcHYiBIB0BjjL6AEgpQFeKDH+DSDz0hngbCULACkN8MIQN3MAkHkpDXCGUACgknti1py76yfPHivvRMgQCoCMS1WA73ohpw/+0w5J0vKWpoSrAYBkpSrA+44PSpLuuPEavfvy8xOuBgCSlaox8Hy0B8qVK1o0i0U8ADIuVQHOHigAcEa6ArxY0uwG0/w5DUmXAgCJS1eAF0pqmT9HZgyfAEDKAnyI4RMAiKQmwAeHR3Ts5BCbWAFAJBUB7u7a8IUf6PEDL2nxAhbwAICUknngJwaH1Zsf0MbLz9cnN16SdDkAUBdS0QPPR9MH337pEl3UvjDhagCgPqQiwE/P/2b/EwA4LR0Bfvou9FzABIAxFQW4mW00s31mtt/MbourqLP1Rz1w9gAHgDNmHOBm1iDpdknvkXSZpBvM7LK4ChsvXyj3wJubGEIBgDGV9MDXStrv7s+5+5Ckr0t6bzxlvdzYGHgzi3gA4LRKAnyZpBfGPX8xOha7XLGkBXMaNKcxFUP2AFATVU9EM9tqZt1m1n3kyJEZvcevLV2o37qyI+bKACDdKgnwg5JWjHu+PDr2Mu6+zd273L2rvb19Rh/0e7++Un/9vqtmViUABKqSAP+ppNVmdqGZzZH0AUkPxlMWAGAyM15K7+7DZvZxSY9IapB0l7vvia0yAMCrqmgvFHd/WNLDMdUCAJgGpnUAQEoR4ACQUgQ4AKQUAQ4AKUWAA0BKmbvX7sPMjkh6foZ/3ibpaIzl1LOstDUr7ZRoa4hq2c4L3P0VKyFrGuCVMLNud+9Kuo5ayEpbs9JOibaGqB7ayRAKAKQUAQ4AKZWmAN+WdAE1lJW2ZqWdEm0NUeLtTM0YOADg5dLUAwcAjJOKAK/VzZNrwczuMrPDZtYz7thiM3vUzJ6JfrdGx83M/jZq98/N7JrkKp8+M1thZo+Z2VNmtsfMbomOB9VeM5tnZo+b2ZNRO/8sOn6hme2I2vNv0bbLMrO50fP90eudSdY/E2bWYGY/M7OHoudBttXMDpjZbjPbZWbd0bG6+f7WfYDX8ubJNXK3pI1nHbtN0nZ3Xy1pe/RcKrd5dfSzVdIdNaoxLsOSbnX3yyStk/Sx6H+70No7KOlt7n6VpDWSNprZOkl/Jekr7n6xpH5JW6Lzt0jqj45/JTovbW6RtHfc85Db+lZ3XzNuymD9fH/dva5/JL1R0iPjnn9K0qeSrqvCNnVK6hn3fJ+kjuhxh6R90eN/lHTDROel8UfSA5LeGXJ7Jc2X9ISk31B5kUdjdPz091jlPfTfGD1ujM6zpGufRhuXqxxcb5P0kCQLuK0HJLWddaxuvr913wNXDW+enKCl7t4bPT4kaWn0OJi2R/90vlrSDgXY3mhIYZekw5IelfSspJy7D0enjG/L6XZGr+clvaa2FVfkbyT9kaTR6PlrFG5bXdJ3zWynmW2NjtXN97eiGzogfu7uZhbU1CAzWyjpm5I+4e7Hzez0a6G0191HJK0xsxZJ90t6XcIlVYWZ/bakw+6+08w2JF1PDax394NmtkTSo2b29PgXk/7+pqEHPqWbJ6dcn5l1SFL0+3B0PPVtN7PZKof3ve7+rehwsO1195ykx1QeRmgxs7FO0vi2nG5n9HqzpGM1LnWmrpX0O2Z2QNLXVR5G+arCbKvc/WD0+7DK/8e8VnX0/U1DgGfh5skPStocPd6s8ljx2PGbo6vb6yTlx/3Tre5Zuat9p6S97v7lcS8F1V4za4963jKzJpXH+feqHOTvi047u51j7X+fp
O97NGha79z9U+6+3N07Vf5v8fvufqMCbKuZLTCzRWOPJb1LUo/q6fub9EWCKV5I2CTpFyqPK34m6XoqbMt9knollVQeI9ui8pjgdknPSPqepMXRuabyDJxnJe2W1JV0/dNs63qVxxB/LmlX9LMptPZKulLSz6J29kj6k+j4KkmPS9ov6RuS5kbH50XP90evr0q6DTNs9wZJD4Xa1qhNT0Y/e8ayp56+v6zEBICUSsMQCgBgAgQ4AKQUAQ4AKUWAA0BKEeAAkFIEOACkFAEOAClFgANASv0/hslqB6Oo/aAAAAAASUVORK5CYII=\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "tags": [], + "needs_background": "light" + } + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "4e8fq4n24LhA", + "colab_type": "code", + "colab": {} + }, + "source": [ + "" + ], + "execution_count": null, + "outputs": [] + } + ] +} \ No newline at end of file diff --git a/examples/DQN_demo.ipynb b/examples/DQN_demo.ipynb new file mode 100644 index 00000000..4b752143 --- /dev/null +++ b/examples/DQN_demo.ipynb @@ -0,0 +1,369 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "GenRL_DQN_Video.ipynb", + "provenance": [], + "collapsed_sections": [], + "include_colab_link": true + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + } + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "view-in-github", + "colab_type": "text" + }, + "source": [ + "\"Open" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "BOkJToH0asLB", + "colab_type": "text" + }, + "source": [ + "# Examples with DQN from GenRL" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "7SATMNIdaxZi", + "colab_type": "text" + }, + "source": [ + "## Setup" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "f7AIp66ky69_", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 1000 + }, + "outputId": "3a3f18e4-572f-4069-80de-ead359bf240e" + }, + "source": [ + "!git clone https://github.com/SforAiDl/genrl\n", + "!pip install -e genrl" + ], + "execution_count": 4, + "outputs": [ + { + "output_type": "stream", + "text": [ + "fatal: destination path 'genrl' already exists and is not an empty directory.\n", + "Obtaining file:///content/genrl\n", + "Requirement already satisfied: atari-py==0.2.6 in /usr/local/lib/python3.6/dist-packages (from genrl==0.0.1) (0.2.6)\n", + "Requirement already satisfied: box2d-py==2.3.8 in /usr/local/lib/python3.6/dist-packages (from genrl==0.0.1) (2.3.8)\n", + "Requirement already satisfied: certifi==2019.11.28 in /usr/local/lib/python3.6/dist-packages (from genrl==0.0.1) (2019.11.28)\n", + "Requirement already satisfied: cloudpickle==1.3.0 in /usr/local/lib/python3.6/dist-packages (from genrl==0.0.1) (1.3.0)\n", + "Requirement already satisfied: future==0.18.2 in /usr/local/lib/python3.6/dist-packages (from genrl==0.0.1) (0.18.2)\n", + "Requirement already satisfied: gym==0.17.1 in /usr/local/lib/python3.6/dist-packages (from genrl==0.0.1) (0.17.1)\n", + "Requirement already satisfied: numpy==1.18.2 in /usr/local/lib/python3.6/dist-packages (from genrl==0.0.1) (1.18.2)\n", + "Requirement already satisfied: opencv-python==4.2.0.34 in /usr/local/lib/python3.6/dist-packages (from genrl==0.0.1) (4.2.0.34)\n", + "Requirement already satisfied: pandas==1.0.4 in /usr/local/lib/python3.6/dist-packages (from genrl==0.0.1) (1.0.4)\n", + "Requirement already satisfied: Pillow==7.1.0 in /usr/local/lib/python3.6/dist-packages (from genrl==0.0.1) (7.1.0)\n", + "Requirement already satisfied: pyglet==1.5.0 in /usr/local/lib/python3.6/dist-packages (from genrl==0.0.1) (1.5.0)\n", + "Requirement already satisfied: scipy==1.4.1 in /usr/local/lib/python3.6/dist-packages (from genrl==0.0.1) (1.4.1)\n", + "Requirement already satisfied: six==1.14.0 in /usr/local/lib/python3.6/dist-packages (from genrl==0.0.1) (1.14.0)\n", + "Requirement already satisfied: matplotlib==3.2.1 in /usr/local/lib/python3.6/dist-packages (from genrl==0.0.1) (3.2.1)\n", + "Requirement already satisfied: pytest==5.4.1 in /usr/local/lib/python3.6/dist-packages (from 
genrl==0.0.1) (5.4.1)\n", + "Requirement already satisfied: torch==1.4.0 in /usr/local/lib/python3.6/dist-packages (from genrl==0.0.1) (1.4.0)\n", + "Requirement already satisfied: torchvision==0.5.0 in /usr/local/lib/python3.6/dist-packages (from genrl==0.0.1) (0.5.0)\n", + "Requirement already satisfied: tensorflow-tensorboard==1.5.1 in /usr/local/lib/python3.6/dist-packages (from genrl==0.0.1) (1.5.1)\n", + "Requirement already satisfied: tensorboard==1.15.0 in /usr/local/lib/python3.6/dist-packages (from genrl==0.0.1) (1.15.0)\n", + "Requirement already satisfied: pre-commit==2.4.0 in /usr/local/lib/python3.6/dist-packages (from genrl==0.0.1) (2.4.0)\n", + "Requirement already satisfied: importlib-resources==1.0.1 in /usr/local/lib/python3.6/dist-packages (from genrl==0.0.1) (1.0.1)\n", + "Requirement already satisfied: setuptools==41.0.0 in /usr/local/lib/python3.6/dist-packages (from genrl==0.0.1) (41.0.0)\n", + "Requirement already satisfied: pytz>=2017.2 in /usr/local/lib/python3.6/dist-packages (from pandas==1.0.4->genrl==0.0.1) (2018.9)\n", + "Requirement already satisfied: python-dateutil>=2.6.1 in /usr/local/lib/python3.6/dist-packages (from pandas==1.0.4->genrl==0.0.1) (2.8.1)\n", + "Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib==3.2.1->genrl==0.0.1) (2.4.7)\n", + "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.6/dist-packages (from matplotlib==3.2.1->genrl==0.0.1) (0.10.0)\n", + "Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib==3.2.1->genrl==0.0.1) (1.2.0)\n", + "Requirement already satisfied: importlib-metadata>=0.12; python_version < \"3.8\" in /usr/local/lib/python3.6/dist-packages (from pytest==5.4.1->genrl==0.0.1) (1.7.0)\n", + "Requirement already satisfied: wcwidth in /usr/local/lib/python3.6/dist-packages (from pytest==5.4.1->genrl==0.0.1) (0.2.5)\n", + "Requirement already satisfied: py>=1.5.0 in /usr/local/lib/python3.6/dist-packages (from pytest==5.4.1->genrl==0.0.1) (1.9.0)\n", + "Requirement already satisfied: more-itertools>=4.0.0 in /usr/local/lib/python3.6/dist-packages (from pytest==5.4.1->genrl==0.0.1) (8.4.0)\n", + "Requirement already satisfied: pluggy<1.0,>=0.12 in /usr/local/lib/python3.6/dist-packages (from pytest==5.4.1->genrl==0.0.1) (0.13.1)\n", + "Requirement already satisfied: packaging in /usr/local/lib/python3.6/dist-packages (from pytest==5.4.1->genrl==0.0.1) (20.4)\n", + "Requirement already satisfied: attrs>=17.4.0 in /usr/local/lib/python3.6/dist-packages (from pytest==5.4.1->genrl==0.0.1) (20.1.0)\n", + "Requirement already satisfied: protobuf>=3.4.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow-tensorboard==1.5.1->genrl==0.0.1) (3.12.4)\n", + "Requirement already satisfied: bleach==1.5.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow-tensorboard==1.5.1->genrl==0.0.1) (1.5.0)\n", + "Requirement already satisfied: wheel>=0.26; python_version >= \"3\" in /usr/local/lib/python3.6/dist-packages (from tensorflow-tensorboard==1.5.1->genrl==0.0.1) (0.35.1)\n", + "Requirement already satisfied: werkzeug>=0.11.10 in /usr/local/lib/python3.6/dist-packages (from tensorflow-tensorboard==1.5.1->genrl==0.0.1) (1.0.1)\n", + "Requirement already satisfied: markdown>=2.6.8 in /usr/local/lib/python3.6/dist-packages (from tensorflow-tensorboard==1.5.1->genrl==0.0.1) (3.2.2)\n", + "Requirement already satisfied: html5lib==0.9999999 in 
/usr/local/lib/python3.6/dist-packages (from tensorflow-tensorboard==1.5.1->genrl==0.0.1) (0.9999999)\n", + "Requirement already satisfied: grpcio>=1.6.3 in /usr/local/lib/python3.6/dist-packages (from tensorboard==1.15.0->genrl==0.0.1) (1.31.0)\n", + "Requirement already satisfied: absl-py>=0.4 in /usr/local/lib/python3.6/dist-packages (from tensorboard==1.15.0->genrl==0.0.1) (0.8.1)\n", + "Requirement already satisfied: toml in /usr/local/lib/python3.6/dist-packages (from pre-commit==2.4.0->genrl==0.0.1) (0.10.1)\n", + "Requirement already satisfied: nodeenv>=0.11.1 in /usr/local/lib/python3.6/dist-packages (from pre-commit==2.4.0->genrl==0.0.1) (1.5.0)\n", + "Requirement already satisfied: identify>=1.0.0 in /usr/local/lib/python3.6/dist-packages (from pre-commit==2.4.0->genrl==0.0.1) (1.4.29)\n", + "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.6/dist-packages (from pre-commit==2.4.0->genrl==0.0.1) (5.3.1)\n", + "Requirement already satisfied: virtualenv>=20.0.8 in /usr/local/lib/python3.6/dist-packages (from pre-commit==2.4.0->genrl==0.0.1) (20.0.31)\n", + "Requirement already satisfied: cfgv>=2.0.0 in /usr/local/lib/python3.6/dist-packages (from pre-commit==2.4.0->genrl==0.0.1) (3.2.0)\n", + "Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.6/dist-packages (from importlib-metadata>=0.12; python_version < \"3.8\"->pytest==5.4.1->genrl==0.0.1) (3.1.0)\n", + "Requirement already satisfied: filelock<4,>=3.0.0 in /usr/local/lib/python3.6/dist-packages (from virtualenv>=20.0.8->pre-commit==2.4.0->genrl==0.0.1) (3.0.12)\n", + "Requirement already satisfied: appdirs<2,>=1.4.3 in /usr/local/lib/python3.6/dist-packages (from virtualenv>=20.0.8->pre-commit==2.4.0->genrl==0.0.1) (1.4.4)\n", + "Requirement already satisfied: distlib<1,>=0.3.1 in /usr/local/lib/python3.6/dist-packages (from virtualenv>=20.0.8->pre-commit==2.4.0->genrl==0.0.1) (0.3.1)\n", + "Installing collected packages: genrl\n", + " Found existing installation: genrl 0.0.1\n", + " Can't uninstall 'genrl'. 
No files were found to uninstall.\n", + " Running setup.py develop for genrl\n", + "Successfully installed genrl\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "Pg_6SWFJ6HqK", + "colab_type": "code", + "colab": {} + }, + "source": [ + "import torch\n", + "\n", + "from genrl.agents import DQN\n", + "from genrl.agents.deep.dqn.utils import ddqn_q_target, prioritized_q_loss\n", + "from genrl.environments import VectorEnv\n", + "from genrl.trainers import OffPolicyTrainer, OnPolicyTrainer" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "v2zEFSI4a_XC", + "colab_type": "text" + }, + "source": [ + "## Training Vanilla DQN on CartPole " + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "ZHynQcDV6JMw", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 336 + }, + "outputId": "99e7a567-a358-4745-c561-0fbf4a2fcdf0" + }, + "source": [ + "env = VectorEnv(\"CartPole-v0\")\n", + "agent = DQN(\"mlp\", env)\n", + "trainer = OffPolicyTrainer(agent, env, max_timesteps=20000)\n", + "trainer.train()\n", + "trainer.evaluate()" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "text": [ + "timestep Episode value_loss epsilon Episode Reward \n", + "44 0.0 0 0.9785 0 \n", + "240 10.0 0 0.8695 20.1 \n", + "450 20.0 0 0.7117 23.0 \n", + "734 30.0 0 0.559 28.2 \n", + "930 40.0 0 0.4411 19.8 \n", + "1120 50.0 0.3117 0.3654 19.4 \n", + "1226 60.0 0.5066 0.3162 10.6 \n", + "1466 70.0 0.9103 0.268 14.4 \n", + "3290 80.0 2.2416 0.115 156.9 \n", + "5290 90.0 9.0437 0.0259 200.0 \n", + "7288 100.0 21.788 0.0122 200.0 \n", + "9234 110.0 20.1689 0.0103 194.4 \n", + "11234 120.0 18.4776 0.01 200.0 \n", + "13234 130.0 16.0524 0.01 200.0 \n", + "15234 140.0 9.4094 0.01 200.0 \n", + "17234 150.0 6.919 0.01 200.0 \n", + "19086 160.0 10.4121 0.01 192.6 \n", + "Evaluated for 50 episodes, Mean Reward: 200.0, Std Deviation for the Reward: 0.0\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "LaCoR9jTbFyE", + "colab_type": "text" + }, + "source": [ + "## Extending DQN to Double DQN" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "v8xXyd8nZafc", + "colab_type": "code", + "colab": {} + }, + "source": [ + "class DoubleDQN(DQN):\n", + " def __init__(self, *args, **kwargs):\n", + " super(DoubleDQN, self).__init__(*args, **kwargs)\n", + " self._create_model()\n", + "\n", + " def get_target_q_values(self, next_states, rewards, dones):\n", + " next_q_value_dist = self.model(next_states)\n", + " next_best_actions = torch.argmax(next_q_value_dist, dim=-1).unsqueeze(-1)\n", + " rewards, dones = rewards.unsqueeze(-1), dones.unsqueeze(-1)\n", + " next_q_target_value_dist = self.target_model(next_states)\n", + " max_next_q_target_values = next_q_target_value_dist.gather(2, next_best_actions)\n", + " target_q_values = rewards + agent.gamma * torch.mul(\n", + " max_next_q_target_values, (1 - dones)\n", + " )\n", + " return target_q_values" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "1DCdOVBfZhuV", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 370 + }, + "outputId": "2edcb444-1daa-43da-b89c-c42c78224ab7" + }, + "source": [ + "env = VectorEnv(\"CartPole-v0\")\n", + "agent = DoubleDQN(\"mlp\", env)\n", + "trainer = OffPolicyTrainer(agent, env, max_timesteps=20000)\n", + "trainer.train()\n", + "trainer.evaluate()" 
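A small self-containedness nit on the `DoubleDQN` cell above: `get_target_q_values` reads `agent.gamma`, i.e. the `agent` variable left over from the vanilla DQN run, rather than the instance's own discount factor. It happens to work inside this notebook because that global exists, but a standalone version should use `self.gamma` (the `ddqn_q_target` helper imported earlier serves the same purpose, and the Duelling DQN cell below calls it directly). A sketch with only that change, keeping the attributes already used above:

    class DoubleDQN(DQN):
        def __init__(self, *args, **kwargs):
            super(DoubleDQN, self).__init__(*args, **kwargs)
            self._create_model()

        def get_target_q_values(self, next_states, rewards, dones):
            # Online network chooses the greedy action, target network evaluates it
            next_best_actions = torch.argmax(self.model(next_states), dim=-1).unsqueeze(-1)
            rewards, dones = rewards.unsqueeze(-1), dones.unsqueeze(-1)
            max_next_q_target_values = self.target_model(next_states).gather(2, next_best_actions)
            # Use the agent's own discount factor instead of the notebook-level `agent`
            return rewards + self.gamma * torch.mul(max_next_q_target_values, (1 - dones))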
+ ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "text": [ + "timestep Episode value_loss epsilon Episode Reward \n", + "26 0.0 0 0.9872 0 \n", + "238 10.0 0 0.8783 19.6 \n", + "404 20.0 0 0.7283 19.1 \n", + "644 30.0 0 0.597 20.1 \n", + "842 40.0 0 0.4812 23.1 \n", + "1054 50.0 0.3035 0.394 16.5 \n", + "1158 60.0 0.4897 0.3374 16.4 \n", + "1288 70.0 0.7634 0.3013 12.8 \n", + "2686 80.0 2.1095 0.1569 101.9 \n", + "4610 90.0 9.0553 0.0399 197.3 \n", + "6372 100.0 19.4667 0.0146 185.1 \n", + "8002 110.0 24.2586 0.0108 166.4 \n", + "9420 120.0 22.9694 0.0102 144.4 \n", + "11102 130.0 16.2487 0.01 162.7 \n", + "12702 140.0 8.8523 0.01 169.4 \n", + "14316 150.0 5.8546 0.01 151.1 \n", + "15550 160.0 7.8662 0.01 132.7 \n", + "17304 170.0 7.644 0.01 170.5 \n", + "19274 180.0 12.3485 0.01 192.9 \n", + "Evaluated for 50 episodes, Mean Reward: 200.0, Std Deviation for the Reward: 0.0\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "KB9b1ENnbSU4", + "colab_type": "text" + }, + "source": [ + "## Extending DQN to Duelling DQN" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "mveQ_7TZZj3a", + "colab_type": "code", + "colab": {} + }, + "source": [ + "class DuelingDQN(DQN):\n", + " def __init__(self, *args, buffer_type=\"push\", **kwargs):\n", + " super(DuelingDQN, self).__init__(*args, buffer_type=buffer_type, **kwargs)\n", + " self.dqn_type = \"dueling\" # can be \"noisy\" for NoisyDQN\n", + " self._create_model()\n", + "\n", + " def get_target_q_values(self, *args):\n", + " return ddqn_q_target(self, *args)\n", + " \n", + " # Prioritized Loss function needs to be imported only if buffer_type is set as prioritized\n", + " def get_q_loss(self, *args):\n", + " return prioritized_q_loss(self, *args)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "ulUF5KppaNrF", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 252 + }, + "outputId": "febf6c38-6ba9-486e-f1c8-fd4b921c74ab" + }, + "source": [ + "env = VectorEnv(\"CartPole-v0\")\n", + "agent = DuelingDQN(\"mlp\", env, buffer_type=\"prioritized\")\n", + "trainer = OffPolicyTrainer(agent, env, max_timesteps=20000)\n", + "trainer.train()\n", + "trainer.evaluate()" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "text": [ + "timestep Episode value_loss epsilon Episode Reward \n", + "24 0.0 0 0.9882 0 \n", + "182 10.0 0 0.9031 16.4 \n", + "392 20.0 0 0.7536 20.8 \n", + "568 30.0 0 0.6228 17.7 \n", + "764 40.0 0 0.5189 18.8 \n", + "1026 50.0 0.6067 0.4153 26.6 \n", + "1172 60.0 0.4115 0.3398 14.5 \n", + "1282 70.0 0.3327 0.3001 11.8 \n", + "1524 80.0 0.3231 0.2538 18.3 \n", + "2684 90.0 0.2799 0.1375 84.5 \n", + "3830 100.0 0.5504 0.0502 140.3 \n", + "5322 110.0 1.15 0.0212 143.5 \n", + "6916 120.0 2.085 0.0124 151.1 \n" + ], + "name": "stdout" + } + ] + } + ] +} \ No newline at end of file From 4cba7272e75de40c5ec208e8b6ac0f4f52d44647 Mon Sep 17 00:00:00 2001 From: Atharv Sonwane Date: Thu, 3 Sep 2020 23:32:33 +0530 Subject: [PATCH 02/27] initial structure --- genrl/trainers/distributed.py | 205 ++++++++++++++++++++++++++++++++++ 1 file changed, 205 insertions(+) create mode 100644 genrl/trainers/distributed.py diff --git a/genrl/trainers/distributed.py b/genrl/trainers/distributed.py new file mode 100644 index 00000000..b2b3b30b --- /dev/null +++ b/genrl/trainers/distributed.py @@ -0,0 +1,205 @@ +import copy +import multiprocessing as mp +from 
typing import Type, Union + +import numpy as np +import reverb +import tensorflow as tf +import torch + +from genrl.trainers import Trainer + + +class ReverbReplayBuffer: + def __init__( + self, + size, + batch_size, + obs_shape, + action_shape, + discrete=True, + reward_shape=(1,), + done_shape=(1,), + n_envs=1, + ): + self.size = size + self.obs_shape = (n_envs, *obs_shape) + self.action_shape = (n_envs, *action_shape) + self.reward_shape = (n_envs, *reward_shape) + self.done_shape = (n_envs, *done_shape) + self.n_envs = n_envs + self.action_dtype = np.int64 if discrete else np.float32 + + self._pos = 0 + self._table = reverb.Table( + name="replay_buffer", + sampler=reverb.selectors.Uniform(), + remover=reverb.selectors.Fifo(), + max_size=self.size, + rate_limiter=reverb.rate_limiters.MinSize(2), + ) + self._server = reverb.Server(tables=[self._table], port=None) + self._server_address = f"localhost:{self._server.port}" + self._client = reverb.Client(self._server_address) + self._dataset = reverb.ReplayDataset( + server_address=self._server_address, + table="replay_buffer", + max_in_flight_samples_per_worker=2 * batch_size, + dtypes=(np.float32, self.action_dtype, np.float32, np.float32, np.bool), + shapes=( + tf.TensorShape([n_envs, *obs_shape]), + tf.TensorShape([n_envs, *action_shape]), + tf.TensorShape([n_envs, *reward_shape]), + tf.TensorShape([n_envs, *obs_shape]), + tf.TensorShape([n_envs, *done_shape]), + ), + ) + self._iterator = self._dataset.batch(batch_size).as_numpy_iterator() + + def push(self, inp): + i = [] + i.append(np.array(inp[0], dtype=np.float32).reshape(self.obs_shape)) + i.append(np.array(inp[1], dtype=self.action_dtype).reshape(self.action_shape)) + i.append(np.array(inp[2], dtype=np.float32).reshape(self.reward_shape)) + i.append(np.array(inp[3], dtype=np.float32).reshape(self.obs_shape)) + i.append(np.array(inp[4], dtype=np.bool).reshape(self.done_shape)) + + self._client.insert(i, priorities={"replay_buffer": 1.0}) + if self._pos < self.size: + self._pos += 1 + + def extend(self, inp): + for sample in inp: + self.push(sample) + + def sample(self, *args, **kwargs): + sample = next(self._iterator) + obs, a, r, next_obs, d = [torch.from_numpy(t).float() for t in sample.data] + return obs, a, r.reshape(-1, self.n_envs), next_obs, d.reshape(-1, self.n_envs) + + def __len__(self): + return self._pos + + def __del__(self): + self._server.stop() + + +class DistributedOffPolicyTrainer(Trainer): + """Distributed Off Policy Trainer Class + + Trainer class for Distributed Off Policy Agents + + """ + + def __init__( + self, + *args, + env, + agent, + max_ep_len: int = 500, + max_timesteps: int = 5000, + update_interval: int = 50, + buffer_server_port=None, + param_server_port=None, + **kwargs, + ): + super(DistributedOffPolicyTrainer, self).__init__( + *args, off_policy=True, max_timesteps=max_timesteps, **kwargs + ) + self.env = env + self.agent = agent + self.max_ep_len = max_ep_len + self.update_interval = update_interval + self.buffer_server_port = buffer_server_port + self.param_server_port = param_server_port + + def train(self, n_actors, max_buffer_size, batch_size, max_updates): + buffer_server = reverb.Server( + tables=[ + reverb.Table( + name="replay_buffer", + sampler=reverb.selectors.Uniform(), + remover=reverb.selectors.Fifo(), + max_size=max_buffer_size, + rate_limiter=reverb.rate_limiters.MinSize(2), + ) + ], + port=self.buffer_server_port, + ) + buffer_server_address = f"localhost:{self.buffer_server.port}" + + param_server = reverb.Server( + tables=[ + 
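One runtime issue in the `train` method being added here: `buffer_server_address` is built from `self.buffer_server.port`, but the Reverb server only exists as the local variable `buffer_server` (the same applies to `param_server_address` a few lines further down), so this version raises an `AttributeError` before any actor starts. The later "added mp" commit in this same series makes exactly this correction:

    # use the local server objects, not attributes on self
    buffer_server_address = f"localhost:{buffer_server.port}"
    param_server_address = f"localhost:{param_server.port}"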
reverb.Table( + name="replay_buffer", + sampler=reverb.selectors.Uniform(), + remover=reverb.selectors.Fifo(), + max_size=1, + ) + ], + port=self.param_server_port, + ) + param_server_address = f"localhost:{self.param_server.port}" + + actor_procs = [] + for _ in range(n_actors): + p = mp.Process( + target=self._run_actor, + args=( + copy.deepcopy(self.agent), + copy.deepcopy(self.env), + buffer_server_address, + param_server_address, + ), + ) + p.daemon = True + actor_procs.append(p) + + learner_proc = mp.Process( + target=self._run_learner, + args=( + self.agent, + max_updates, + buffer_server_address, + param_server_address, + batch_size, + ), + ) + learner_proc.daemon = True + + def _run_actor(self, agent, env, buffer_server_address, param_server_address): + buffer_client = reverb.Client(buffer_server_address) + param_client = reverb.Client(param_server_address) + + state = env.reset() + + while True: + params = param_client.sample(table="replay_buffer") + agent.load_weights(params) + + action = self.get_action(state) + next_state, reward, done, info = self.env.step(action) + + state = next_state.clone() + + buffer_client.insert([state, action, reward, done, next_state]) + + def _run_learner( + self, + agent, + max_updates, + buffer_server_address, + param_server_address, + batch_size, + ): + param_client = reverb.Client(param_server_address) + dataset = reverb.ReplayDataset( + server_address=buffer_server_address, + table="replay_buffer", + ) + data_iter = dataset.batch(batch_size).as_numpy_iterator() + + for _ in range(max_updates): + sample = next(data_iter) + agent.update_params(sample) + param_client.insert(agent.get_weights()) From 3d233c47d625f7749c46fc0cea6dcdf35e42b0a6 Mon Sep 17 00:00:00 2001 From: Atharv Sonwane Date: Mon, 7 Sep 2020 00:01:44 +0530 Subject: [PATCH 03/27] added mp --- examples/distributed.py | 76 +++++++++ genrl/agents/deep/base/offpolicy.py | 2 +- genrl/agents/deep/ddpg/ddpg.py | 3 + genrl/trainers/distributed.py | 236 +++++++++++++--------------- 4 files changed, 187 insertions(+), 130 deletions(-) create mode 100644 examples/distributed.py diff --git a/examples/distributed.py b/examples/distributed.py new file mode 100644 index 00000000..ead8c8c6 --- /dev/null +++ b/examples/distributed.py @@ -0,0 +1,76 @@ +from genrl.agents import DDPG +from genrl.trainers import OffPolicyTrainer +from genrl.trainers.distributed import DistributedOffPolicyTrainer +from genrl.environments import VectorEnv +import gym +import reverb +import numpy as np +import multiprocessing as mp +import threading + + +# env = VectorEnv("Pendulum-v0") +# agent = DDPG("mlp", env) +# trainer = OffPolicyTrainer(agent, env) +# trainer.train() + + +env = gym.make("Pendulum-v0") +agent = DDPG("mlp", env) + +# o = env.reset() +# action = agent.select_action(o) +# next_state, reward, done, info = env.step(action.numpy()) + +# buffer_server = reverb.Server( +# tables=[ +# reverb.Table( +# name="replay_buffer", +# sampler=reverb.selectors.Uniform(), +# remover=reverb.selectors.Fifo(), +# max_size=10, +# rate_limiter=reverb.rate_limiters.MinSize(4), +# ) +# ], +# port=None, +# ) +# client = reverb.Client(f"localhost:{buffer_server.port}") +# print(client.server_info()) + +# state = env.reset() +# action = agent.select_action(state) +# next_state, reward, done, info = env.step(action.numpy()) + +# state = next_state.copy() +# print(client.server_info()) +# print("going to insert") +# client.insert([state, action, np.array([reward]), np.array([done]), next_state], {"replay_buffer": 1}) +# 
client.insert([state, action, np.array([reward]), np.array([done]), next_state], {"replay_buffer": 1}) +# client.insert([state, action, np.array([reward]), np.array([done]), next_state], {"replay_buffer": 1}) +# client.insert([state, action, np.array([reward]), np.array([done]), next_state], {"replay_buffer": 1}) +# client.insert([state, action, np.array([reward]), np.array([done]), next_state], {"replay_buffer": 1}) +# print("inserted") + +# # print(list(client.sample('replay_buffer', num_samples=1))) + + +# def sample(address): +# print("-- entered proc") +# client = reverb.Client(address) +# print("-- started client") +# print(list(client.sample('replay_buffer', num_samples=1))) + +# a = f"localhost:{buffer_server.port}" +# print("create process") +# # p = mp.Process(target=sample, args=(a,)) +# p = threading.Thread(target=sample, args=(a,)) +# print("start process") +# p.start() +# print("wait process") +# p.join() +# print("end process") + +trainer = DistributedOffPolicyTrainer(agent, env) +trainer.train( + n_actors=2, max_buffer_size=100, batch_size=4, max_updates=10, update_interval=1 +) diff --git a/genrl/agents/deep/base/offpolicy.py b/genrl/agents/deep/base/offpolicy.py index 0f294042..8e14d186 100644 --- a/genrl/agents/deep/base/offpolicy.py +++ b/genrl/agents/deep/base/offpolicy.py @@ -106,7 +106,7 @@ def sample_from_buffer(self, beta: float = None): *[states, actions, rewards, next_states, dones, indices, weights] ) else: - raise NotImplementedError + batch = ReplayBufferSamples(*[states, actions, rewards, next_states, dones]) return batch def get_q_loss(self, batch: collections.namedtuple) -> torch.Tensor: diff --git a/genrl/agents/deep/ddpg/ddpg.py b/genrl/agents/deep/ddpg/ddpg.py index b21b6808..a1122299 100644 --- a/genrl/agents/deep/ddpg/ddpg.py +++ b/genrl/agents/deep/ddpg/ddpg.py @@ -123,6 +123,9 @@ def get_hyperparams(self) -> Dict[str, Any]: } return hyperparams + def get_weights(self): + return self.ac.state_dict() + def get_logging_params(self) -> Dict[str, Any]: """Gets relevant parameters for logging diff --git a/genrl/trainers/distributed.py b/genrl/trainers/distributed.py index b2b3b30b..a768820b 100644 --- a/genrl/trainers/distributed.py +++ b/genrl/trainers/distributed.py @@ -1,7 +1,9 @@ import copy import multiprocessing as mp +import threading from typing import Type, Union +import gym import numpy as np import reverb import tensorflow as tf @@ -10,81 +12,7 @@ from genrl.trainers import Trainer -class ReverbReplayBuffer: - def __init__( - self, - size, - batch_size, - obs_shape, - action_shape, - discrete=True, - reward_shape=(1,), - done_shape=(1,), - n_envs=1, - ): - self.size = size - self.obs_shape = (n_envs, *obs_shape) - self.action_shape = (n_envs, *action_shape) - self.reward_shape = (n_envs, *reward_shape) - self.done_shape = (n_envs, *done_shape) - self.n_envs = n_envs - self.action_dtype = np.int64 if discrete else np.float32 - - self._pos = 0 - self._table = reverb.Table( - name="replay_buffer", - sampler=reverb.selectors.Uniform(), - remover=reverb.selectors.Fifo(), - max_size=self.size, - rate_limiter=reverb.rate_limiters.MinSize(2), - ) - self._server = reverb.Server(tables=[self._table], port=None) - self._server_address = f"localhost:{self._server.port}" - self._client = reverb.Client(self._server_address) - self._dataset = reverb.ReplayDataset( - server_address=self._server_address, - table="replay_buffer", - max_in_flight_samples_per_worker=2 * batch_size, - dtypes=(np.float32, self.action_dtype, np.float32, np.float32, np.bool), - 
shapes=( - tf.TensorShape([n_envs, *obs_shape]), - tf.TensorShape([n_envs, *action_shape]), - tf.TensorShape([n_envs, *reward_shape]), - tf.TensorShape([n_envs, *obs_shape]), - tf.TensorShape([n_envs, *done_shape]), - ), - ) - self._iterator = self._dataset.batch(batch_size).as_numpy_iterator() - - def push(self, inp): - i = [] - i.append(np.array(inp[0], dtype=np.float32).reshape(self.obs_shape)) - i.append(np.array(inp[1], dtype=self.action_dtype).reshape(self.action_shape)) - i.append(np.array(inp[2], dtype=np.float32).reshape(self.reward_shape)) - i.append(np.array(inp[3], dtype=np.float32).reshape(self.obs_shape)) - i.append(np.array(inp[4], dtype=np.bool).reshape(self.done_shape)) - - self._client.insert(i, priorities={"replay_buffer": 1.0}) - if self._pos < self.size: - self._pos += 1 - - def extend(self, inp): - for sample in inp: - self.push(sample) - - def sample(self, *args, **kwargs): - sample = next(self._iterator) - obs, a, r, next_obs, d = [torch.from_numpy(t).float() for t in sample.data] - return obs, a, r.reshape(-1, self.n_envs), next_obs, d.reshape(-1, self.n_envs) - - def __len__(self): - return self._pos - - def __del__(self): - self._server.stop() - - -class DistributedOffPolicyTrainer(Trainer): +class DistributedOffPolicyTrainer: """Distributed Off Policy Trainer Class Trainer class for Distributed Off Policy Agents @@ -93,27 +21,20 @@ class DistributedOffPolicyTrainer(Trainer): def __init__( self, - *args, - env, agent, - max_ep_len: int = 500, - max_timesteps: int = 5000, - update_interval: int = 50, + env, buffer_server_port=None, param_server_port=None, **kwargs, ): - super(DistributedOffPolicyTrainer, self).__init__( - *args, off_policy=True, max_timesteps=max_timesteps, **kwargs - ) self.env = env self.agent = agent - self.max_ep_len = max_ep_len - self.update_interval = update_interval self.buffer_server_port = buffer_server_port self.param_server_port = param_server_port - def train(self, n_actors, max_buffer_size, batch_size, max_updates): + def train( + self, n_actors, max_buffer_size, batch_size, max_updates, update_interval + ): buffer_server = reverb.Server( tables=[ reverb.Table( @@ -121,85 +42,142 @@ def train(self, n_actors, max_buffer_size, batch_size, max_updates): sampler=reverb.selectors.Uniform(), remover=reverb.selectors.Fifo(), max_size=max_buffer_size, - rate_limiter=reverb.rate_limiters.MinSize(2), + rate_limiter=reverb.rate_limiters.MinSize(1), ) ], port=self.buffer_server_port, ) - buffer_server_address = f"localhost:{self.buffer_server.port}" + buffer_server_address = f"localhost:{buffer_server.port}" param_server = reverb.Server( tables=[ reverb.Table( - name="replay_buffer", + name="param_buffer", sampler=reverb.selectors.Uniform(), remover=reverb.selectors.Fifo(), max_size=1, + rate_limiter=reverb.rate_limiters.MinSize(1), ) ], port=self.param_server_port, ) - param_server_address = f"localhost:{self.param_server.port}" + param_server_address = f"localhost:{param_server.port}" actor_procs = [] for _ in range(n_actors): - p = mp.Process( - target=self._run_actor, + p = threading.Thread( + target=run_actor, args=( copy.deepcopy(self.agent), copy.deepcopy(self.env), buffer_server_address, param_server_address, ), + daemon=True, ) - p.daemon = True + p.start() actor_procs.append(p) - learner_proc = mp.Process( - target=self._run_learner, + learner_proc = threading.Thread( + target=run_learner, args=( - self.agent, + copy.deepcopy(self.agent), max_updates, + update_interval, buffer_server_address, param_server_address, batch_size, ), + 
daemon=True, ) learner_proc.daemon = True + learner_proc.start() + learner_proc.join() + + # param_client = reverb.Client(param_server_address) + # self.agent.replay_buffer = ReverbReplayDataset( + # self.agent.env, buffer_server_address, batch_size + # ) + + # for _ in range(max_updates): + # self.agent.update_params(update_interval) + # params = self.agent.get_weights() + # param_client.insert(params.values(), {"param_buffer": 1}) + # print("weights updated") + # # print(list(param_client.sample("param_buffer"))) + + +def run_actor(agent, env, buffer_server_address, param_server_address): + buffer_client = reverb.Client(buffer_server_address) + param_client = reverb.TFClient(param_server_address) + + state = env.reset().astype(np.float32) + + for i in range(10): + # params = param_client.sample("param_buffer", []) + # print("Sampling done") + # print(list(params)) + # agent.load_weights(params) + + action = agent.select_action(state).numpy() + next_state, reward, done, _ = env.step(action) + next_state = next_state.astype(np.float32) + reward = np.array([reward]).astype(np.float32) + done = np.array([done]).astype(np.bool) + + buffer_client.insert([state, action, reward, next_state, done], {"replay_buffer": 1}) + print("transition inserted") + state = env.reset().astype(np.float32) if done else next_state.copy() + + +def run_learner( + agent, + max_updates, + update_interval, + buffer_server_address, + param_server_address, + batch_size, +): + param_client = reverb.Client(param_server_address) + agent.replay_buffer = ReverbReplayDataset( + agent.env, buffer_server_address, batch_size + ) + for _ in range(max_updates): + agent.update_params(update_interval) + params = agent.get_weights() + param_client.insert(params.values(), {"param_buffer": 1}) + print("weights updated") + # print(list(param_client.sample("param_buffer"))) + + +class ReverbReplayDataset: + def __init__(self, env, address, batch_size): + action_dtype = ( + np.int64 + if isinstance(env.action_space, gym.spaces.discrete.Discrete) + else np.float32 + ) + obs_shape = env.observation_space.shape + action_shape = env.action_space.shape + reward_shape = 1 + done_shape = 1 - def _run_actor(self, agent, env, buffer_server_address, param_server_address): - buffer_client = reverb.Client(buffer_server_address) - param_client = reverb.Client(param_server_address) - - state = env.reset() - - while True: - params = param_client.sample(table="replay_buffer") - agent.load_weights(params) - - action = self.get_action(state) - next_state, reward, done, info = self.env.step(action) - - state = next_state.clone() - - buffer_client.insert([state, action, reward, done, next_state]) - - def _run_learner( - self, - agent, - max_updates, - buffer_server_address, - param_server_address, - batch_size, - ): - param_client = reverb.Client(param_server_address) - dataset = reverb.ReplayDataset( - server_address=buffer_server_address, + self._dataset = reverb.ReplayDataset( + server_address=address, table="replay_buffer", + max_in_flight_samples_per_worker=2 * batch_size, + dtypes=(np.float32, action_dtype, np.float32, np.float32, np.bool), + shapes=( + tf.TensorShape(obs_shape), + tf.TensorShape(action_shape), + tf.TensorShape(reward_shape), + tf.TensorShape(obs_shape), + tf.TensorShape(done_shape), + ), ) - data_iter = dataset.batch(batch_size).as_numpy_iterator() + self._data_iter = self._dataset.batch(batch_size).as_numpy_iterator() - for _ in range(max_updates): - sample = next(data_iter) - agent.update_params(sample) - 
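A note on the weight hand-off in this revision: `run_learner` publishes `param_client.insert(params.values(), {"param_buffer": 1})`, which stores the parameter tensors but drops the state-dict keys, while the actor-side load path (`agent.load_weights`, which feeds a dict to `load_state_dict`) needs the names; consistently, the weight pull in `run_actor` is still commented out at this stage. A placeholder sketch of the actor-side refresh once a full named state dict is available (`fetch_latest_state_dict` is hypothetical, standing in for whatever transport ends up being used; the RPC parameter server introduced in the next commit plays this role):

    params = fetch_latest_state_dict()  # hypothetical transport, e.g. an RPC call
    if params is not None:
        agent.load_weights(params)      # hands the named state dict to load_state_dict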
param_client.insert(agent.get_weights()) + def sample(self, *args, **kwargs): + sample = next(self._data_iter) + obs, a, r, next_obs, d = [torch.from_numpy(t).float() for t in sample.data] + return obs, a, r, next_obs, d From 1c504cc3142bd210b070abda782cccc644856e3c Mon Sep 17 00:00:00 2001 From: Atharv Sonwane Date: Sun, 20 Sep 2020 19:38:18 +0530 Subject: [PATCH 04/27] add files --- examples/distributed2.py | 275 +++++++++++++++++++++++++ genrl/agents/deep/base/offpolicy.py | 2 +- genrl/core/buffers.py | 3 + genrl/distributed/actor.py | 0 genrl/distributed/parameter_server.py | 0 genrl/distributed/transition_server.py | 2 + 6 files changed, 281 insertions(+), 1 deletion(-) create mode 100644 examples/distributed2.py create mode 100644 genrl/distributed/actor.py create mode 100644 genrl/distributed/parameter_server.py create mode 100644 genrl/distributed/transition_server.py diff --git a/examples/distributed2.py b/examples/distributed2.py new file mode 100644 index 00000000..5abf7a95 --- /dev/null +++ b/examples/distributed2.py @@ -0,0 +1,275 @@ +from genrl.core.buffers import ReplayBuffer +import os + +from genrl.agents import DDPG +import torch +import torch.distributed.rpc as rpc +import torch.multiprocessing as mp +import torch.nn as nn +import torch.nn.functional as F +from torch import optim +import argparse + +import copy + +import gym +import numpy as np + +os.environ["MASTER_ADDR"] = "localhost" +os.environ["MASTER_PORT"] = "29500" + +# to call a function on an rref, we could do the following +# _remote_method(some_func, rref, *args) + +def _call_method(method, rref, *args, **kwargs): + return method(rref.local_value(), *args, **kwargs) + + +def _remote_method(method, rref, *args, **kwargs): + args = [method, rref] + list(args) + return rpc.rpc_sync(rref.owner(), _call_method, args=args, kwargs=kwargs) + + +gloabl_lock = mp.Lock() + + +class ParamServer: + def __init__(self, init_params): + self.params = init_params + # self.lock = mp.Lock() + + def store_params(self, new_params): + # with self.lock: + with gloabl_lock: + self.params = new_params + + def get_params(self): + # with self.lock: + with gloabl_lock: + return self.params + + +class DistributedReplayBuffer: + def __init__(self, size): + self.size = size + self.len = 0 + self._buffer = ReplayBuffer(self.size) + + +class DistributedOffPolicyTrainer: + """Distributed Off Policy Trainer Class + + Trainer class for Distributed Off Policy Agents + + """ + def __init__( + self, + agent, + env, + **kwargs, + ): + self.env = env + self.agent = agent + + def train( + self, n_actors, max_buffer_size, batch_size, max_updates, update_interval + ): + + print("a") + world_size = n_actors + 2 + completed = mp.Value("i", 0) + print("a") + param_server_rref_q = mp.Queue(1) + param_server_p = mp.Process( + target=run_param_server, args=(param_server_rref_q, world_size,) + ) + param_server_p.start() + param_server_rref = param_server_rref_q.get() + param_server_rref_q.close() + + print("a") + buffer_rref_q = mp.Queue(1) + buffer_p = mp.Process(target=run_buffer, args=(max_buffer_size, buffer_rref_q, world_size,)) + buffer_p.start() + buffer_rref = buffer_rref_q.get() + buffer_rref_q.close() + print("a") + + actor_ps = [] + for i in range(n_actors): + a_p = mp.Process( + target=run_actor, + args=( + i, + copy.deepcopy(self.agent), + copy.deepcopy(self.env), + param_server_rref,~ + buffer_rref, + world_size, + completed + ), + ) + a_p.start() + actor_ps.append(a_p) + + learner_p = mp.Process( + target=run_learner, + args=(max_updates, 
batch_size, self.agent, param_server_rref, buffer_rref, world_size, completed), + ) + learner_p.start() + + learner_p.join() + for a in actor_ps: + a.join() + buffer_p.join() + param_server_p.join() + + +def run_param_server(q, world_size): + print("Running parameter server") + rpc.init_rpc(name="param_server", rank=0, world_size=world_size) + print("d") + param_server = ParamServer(None) + param_server_rref = rpc.RRef(param_server) + q.put(param_server_rref) + rpc.shutdown() + print("param server shutting down") + + +def run_buffer(max_buffer_size, q, world_size): + print("Running buffer server") + rpc.init_rpc(name="buffer", rank=1, world_size=world_size) + buffer = ReplayBuffer(max_buffer_size) + buffer_rref = rpc.RRef(buffer) + q.put(buffer_rref) + rpc.shutdown() + print("buffer shutting down") + + +def run_learner(max_updates, batch_size, agent, param_server_rref, buffer_rref, world_size, completed): + print("Running learner") + rpc.init_rpc(name="learner", rank=world_size - 1, world_size=world_size) + i = 0 + while i < max_updates: + batch = _remote_method(ReplayBuffer.sample, buffer_rref, batch_size) + if batch is None: + continue + agent.update_params(batch) + _remote_method(ParamServer.store_params, param_server_rref, agent.get_weights()) + print("weights updated") + i += 1 + print(i) + completed.value = 1 + rpc.shutdown() + print("learner shutting down") + + +def run_actor(i, agent, env, param_server_rref, buffer_rref, world_size, completed): + print(f"Running actor {i}") + + rpc.init_rpc(name=f"action_{i}", rank=i + 1, world_size=world_size) + + state = env.reset().astype(np.float32) + + while not completed.value == 1: + params = _remote_method(ParamServer.get_params, param_server_rref) + agent.load_weights(params) + + action = agent.select_action(state).numpy() + next_state, reward, done, _ = env.step(action) + next_state = next_state.astype(np.float32) + reward = np.array([reward]).astype(np.float32) + done = np.array([done]).astype(np.bool) + + print("attempting to insert transition") + _remote_method(ReplayBuffer.push, buffer_rref, [state, action, reward, next_state, done]) + print("inserted transition") + state = env.reset().astype(np.float32) if done else next_state.copy() + + rpc.shutdown() + print("actor shutting down") + +env = gym.make("Pendulum-v0") +agent = DDPG("mlp", env) + +trainer = DistributedOffPolicyTrainer(agent, env) +trainer.train( + n_actors=1, max_buffer_size=100, batch_size=1, max_updates=100, update_interval=1 +) + + +# if __name__ == '__main__': +# parser = argparse.ArgumentParser( +# description="Parameter-Server RPC based training") +# parser.add_argument( +# "--world_size", +# type=int, +# default=4, +# help="""Total number of participating processes. Should be the number +# of actors + 3.""") +# parser.add_argument( +# "--run", +# type=str, +# default="param_server", +# choices=["param_server", "buffer", "learner", "actor"], +# help="Which program to run") +# parser.add_argument( +# "--master_addr", +# type=str, +# default="localhost", +# help="""Address of master, will default to localhost if not provided. +# Master must be able to accept network traffic on the address + port.""") +# parser.add_argument( +# "--master_port", +# type=str, +# default="29500", +# help="""Port that master is listening on, will default to 29500 if not +# provided. 
Master must be able to accept network traffic on the host and port.""") + +# args = parser.parse_args() + +# os.environ['MASTER_ADDR'] = args.master_addr +# os.environ["MASTER_PORT"] = args.master_port + +# processes = [] +# world_size = args.world_size +# if args.run == "param_server": +# p = mp.Process(target=run_param_server, args=(world_size)) +# p.start() +# processes.append(p) +# elif args.run == "buffer": +# p = mp.Process(target=run_buffer, args=(world_size)) +# p.start() +# processes.append(p) +# # Get data to train on +# train_loader = torch.utils.data.DataLoader( +# datasets.MNIST('../data', train=True, download=True, +# transform=transforms.Compose([ +# transforms.ToTensor(), +# transforms.Normalize((0.1307,), (0.3081,)) +# ])), +# batch_size=32, shuffle=True,) +# test_loader = torch.utils.data.DataLoader( +# datasets.MNIST( +# '../data', +# train=False, +# transform=transforms.Compose([ +# transforms.ToTensor(), +# transforms.Normalize((0.1307,), (0.3081,)) +# ])), +# batch_size=32, +# shuffle=True, +# ) +# # start training worker on this node +# p = mp.Process( +# target=run_worker, +# args=( +# args.rank, +# world_size, args.num_gpus, +# train_loader, +# test_loader)) +# p.start() +# processes.append(p) + +# for p in processes: +# p.join() diff --git a/genrl/agents/deep/base/offpolicy.py b/genrl/agents/deep/base/offpolicy.py index 8e14d186..459ed58d 100644 --- a/genrl/agents/deep/base/offpolicy.py +++ b/genrl/agents/deep/base/offpolicy.py @@ -277,4 +277,4 @@ def load_weights(self, weights) -> None: Args: weights (:obj:`dict`): Dictionary of different neural net weights """ - self.ac.load_state_dict(weights["weights"]) + self.ac.load_state_dict(weights) diff --git a/genrl/core/buffers.py b/genrl/core/buffers.py index 0a5b6e7c..b73067f1 100644 --- a/genrl/core/buffers.py +++ b/genrl/core/buffers.py @@ -57,6 +57,9 @@ def sample( :returns: (Tuple composing of `state`, `action`, `reward`, `next_state` and `done`) """ + if batch_size > len(self.memory): + return None + batch = random.sample(self.memory, batch_size) state, action, reward, next_state, done = map(np.stack, zip(*batch)) return [ diff --git a/genrl/distributed/actor.py b/genrl/distributed/actor.py new file mode 100644 index 00000000..e69de29b diff --git a/genrl/distributed/parameter_server.py b/genrl/distributed/parameter_server.py new file mode 100644 index 00000000..e69de29b diff --git a/genrl/distributed/transition_server.py b/genrl/distributed/transition_server.py new file mode 100644 index 00000000..a8f6e256 --- /dev/null +++ b/genrl/distributed/transition_server.py @@ -0,0 +1,2 @@ +from genrl.core.buffers import ReplayBuffer + From 2c3298a725727f991b5ae26db6fe7ca85734e321 Mon Sep 17 00:00:00 2001 From: Atharv Sonwane Date: Thu, 1 Oct 2020 18:05:23 +0530 Subject: [PATCH 05/27] added new structure on rpc --- examples/distributed.py | 133 +++++++++--------- examples/distributed_old_1.py | 76 ++++++++++ .../{distributed2.py => distributed_old_2.py} | 0 genrl/distributed/__init__.py | 5 + genrl/distributed/actor.py | 74 ++++++++++ genrl/distributed/core.py | 92 ++++++++++++ genrl/distributed/experience_server.py | 25 ++++ genrl/distributed/learner.py | 48 +++++++ genrl/distributed/parameter_server.py | 37 +++++ genrl/distributed/transition_server.py | 2 - genrl/distributed/utils.py | 33 +++++ 11 files changed, 458 insertions(+), 67 deletions(-) create mode 100644 examples/distributed_old_1.py rename examples/{distributed2.py => distributed_old_2.py} (100%) create mode 100644 genrl/distributed/__init__.py create mode 
100644 genrl/distributed/core.py create mode 100644 genrl/distributed/experience_server.py create mode 100644 genrl/distributed/learner.py delete mode 100644 genrl/distributed/transition_server.py create mode 100644 genrl/distributed/utils.py diff --git a/examples/distributed.py b/examples/distributed.py index ead8c8c6..76002cde 100644 --- a/examples/distributed.py +++ b/examples/distributed.py @@ -1,76 +1,79 @@ +from genrl.distributed import ( + Master, + ExperienceServer, + ParameterServer, + ActorNode, + LearnerNode, + remote_method, + DistributedTrainer, + WeightHolder, +) +from genrl.core import ReplayBuffer from genrl.agents import DDPG -from genrl.trainers import OffPolicyTrainer -from genrl.trainers.distributed import DistributedOffPolicyTrainer -from genrl.environments import VectorEnv import gym -import reverb -import numpy as np -import multiprocessing as mp -import threading - - -# env = VectorEnv("Pendulum-v0") -# agent = DDPG("mlp", env) -# trainer = OffPolicyTrainer(agent, env) -# trainer.train() - -env = gym.make("Pendulum-v0") -agent = DDPG("mlp", env) - -# o = env.reset() -# action = agent.select_action(o) -# next_state, reward, done, info = env.step(action.numpy()) -# buffer_server = reverb.Server( -# tables=[ -# reverb.Table( -# name="replay_buffer", -# sampler=reverb.selectors.Uniform(), -# remover=reverb.selectors.Fifo(), -# max_size=10, -# rate_limiter=reverb.rate_limiters.MinSize(4), -# ) -# ], -# port=None, -# ) -# client = reverb.Client(f"localhost:{buffer_server.port}") -# print(client.server_info()) +N_ACTORS = 2 +BUFFER_SIZE = 10 +MAX_ENV_STEPS = 100 +TRAIN_STEPS = 10 +BATCH_SIZE = 5 -# state = env.reset() -# action = agent.select_action(state) -# next_state, reward, done, info = env.step(action.numpy()) -# state = next_state.copy() -# print(client.server_info()) -# print("going to insert") -# client.insert([state, action, np.array([reward]), np.array([done]), next_state], {"replay_buffer": 1}) -# client.insert([state, action, np.array([reward]), np.array([done]), next_state], {"replay_buffer": 1}) -# client.insert([state, action, np.array([reward]), np.array([done]), next_state], {"replay_buffer": 1}) -# client.insert([state, action, np.array([reward]), np.array([done]), next_state], {"replay_buffer": 1}) -# client.insert([state, action, np.array([reward]), np.array([done]), next_state], {"replay_buffer": 1}) -# print("inserted") +def collect_experience(agent, experience_server_rref): + obs = agent.env.reset() + done = False + for i in range(MAX_ENV_STEPS): + action = agent.select_action(obs) + next_obs, reward, done, info = agent.env.step() + print("Sending experience") + remote_method(ReplayBuffer.push, experience_server_rref) + print("Done sending experience") + if done: + break -# # print(list(client.sample('replay_buffer', num_samples=1))) +class MyTrainer(DistributedTrainer): + def __init__(self, agent, train_steps, batch_size): + super(MyTrainer, self).__init__(agent) + self.train_steps = train_steps + self.batch_size = batch_size -# def sample(address): -# print("-- entered proc") -# client = reverb.Client(address) -# print("-- started client") -# print(list(client.sample('replay_buffer', num_samples=1))) + def train(self, parameter_server_rref, experience_server_rref): + for i in range(self.train_steps): + batch = remote_method( + ReplayBuffer.sample, parameter_server_rref, self.batch_size + ) + if batch is None: + continue + self.agent.update_params(batch) + print("Storing weights") + remote_method( + WeightHolder.store_weights, + 
parameter_server_rref, + self.agent.get_weights(), + ) + print("Done storing weights") -# a = f"localhost:{buffer_server.port}" -# print("create process") -# # p = mp.Process(target=sample, args=(a,)) -# p = threading.Thread(target=sample, args=(a,)) -# print("start process") -# p.start() -# print("wait process") -# p.join() -# print("end process") -trainer = DistributedOffPolicyTrainer(agent, env) -trainer.train( - n_actors=2, max_buffer_size=100, batch_size=4, max_updates=10, update_interval=1 -) +master = Master(world_size=6, address="localhost") +print("inited master") +env = gym.make("Pendulum-v0") +agent = DDPG(env) +parameter_server = ParameterServer("param-0", master, agent.get_weights, rank=1) +experience_server = ExperienceServer("experience-0", master, BUFFER_SIZE, rank=2) +trainer = MyTrainer(agent, TRAIN_STEPS, BATCH_SIZE) +learner = LearnerNode("learner-0", master, parameter_server, experience_server, trainer, rank=3) +actors = [ + ActorNode( + f"actor-{i}", + master, + parameter_server, + experience_server, + learner, + agent, + collect_experience, + rank=i+4 + ) + for i in range(N_ACTORS) +] diff --git a/examples/distributed_old_1.py b/examples/distributed_old_1.py new file mode 100644 index 00000000..ead8c8c6 --- /dev/null +++ b/examples/distributed_old_1.py @@ -0,0 +1,76 @@ +from genrl.agents import DDPG +from genrl.trainers import OffPolicyTrainer +from genrl.trainers.distributed import DistributedOffPolicyTrainer +from genrl.environments import VectorEnv +import gym +import reverb +import numpy as np +import multiprocessing as mp +import threading + + +# env = VectorEnv("Pendulum-v0") +# agent = DDPG("mlp", env) +# trainer = OffPolicyTrainer(agent, env) +# trainer.train() + + +env = gym.make("Pendulum-v0") +agent = DDPG("mlp", env) + +# o = env.reset() +# action = agent.select_action(o) +# next_state, reward, done, info = env.step(action.numpy()) + +# buffer_server = reverb.Server( +# tables=[ +# reverb.Table( +# name="replay_buffer", +# sampler=reverb.selectors.Uniform(), +# remover=reverb.selectors.Fifo(), +# max_size=10, +# rate_limiter=reverb.rate_limiters.MinSize(4), +# ) +# ], +# port=None, +# ) +# client = reverb.Client(f"localhost:{buffer_server.port}") +# print(client.server_info()) + +# state = env.reset() +# action = agent.select_action(state) +# next_state, reward, done, info = env.step(action.numpy()) + +# state = next_state.copy() +# print(client.server_info()) +# print("going to insert") +# client.insert([state, action, np.array([reward]), np.array([done]), next_state], {"replay_buffer": 1}) +# client.insert([state, action, np.array([reward]), np.array([done]), next_state], {"replay_buffer": 1}) +# client.insert([state, action, np.array([reward]), np.array([done]), next_state], {"replay_buffer": 1}) +# client.insert([state, action, np.array([reward]), np.array([done]), next_state], {"replay_buffer": 1}) +# client.insert([state, action, np.array([reward]), np.array([done]), next_state], {"replay_buffer": 1}) +# print("inserted") + +# # print(list(client.sample('replay_buffer', num_samples=1))) + + +# def sample(address): +# print("-- entered proc") +# client = reverb.Client(address) +# print("-- started client") +# print(list(client.sample('replay_buffer', num_samples=1))) + +# a = f"localhost:{buffer_server.port}" +# print("create process") +# # p = mp.Process(target=sample, args=(a,)) +# p = threading.Thread(target=sample, args=(a,)) +# print("start process") +# p.start() +# print("wait process") +# p.join() +# print("end process") + +trainer = 
DistributedOffPolicyTrainer(agent, env) +trainer.train( + n_actors=2, max_buffer_size=100, batch_size=4, max_updates=10, update_interval=1 +) diff --git a/examples/distributed2.py b/examples/distributed_old_2.py similarity index 100% rename from examples/distributed2.py rename to examples/distributed_old_2.py diff --git a/genrl/distributed/__init__.py b/genrl/distributed/__init__.py new file mode 100644 index 00000000..54751712 --- /dev/null +++ b/genrl/distributed/__init__.py @@ -0,0 +1,5 @@ +from genrl.distributed.core import Master, DistributedTrainer, remote_method, Node +from genrl.distributed.parameter_server import ParameterServer, WeightHolder +from genrl.distributed.experience_server import ExperienceServer +from genrl.distributed.actor import ActorNode +from genrl.distributed.learner import LearnerNode diff --git a/genrl/distributed/actor.py b/genrl/distributed/actor.py index e69de29b..2f1deaa6 100644 --- a/genrl/distributed/actor.py +++ b/genrl/distributed/actor.py @@ -0,0 +1,74 @@ +from genrl.distributed.core import ( + DistributedTrainer, + Master, + Node, +) +from genrl.distributed.parameter_server import WeightHolder +from genrl.distributed.utils import remote_method, set_environ + +import torch.multiprocessing as mp +import torch.distributed.rpc as rpc + + +class ActorNode(Node): + def __init__( + self, + name, + master, + parameter_server, + experience_server, + learner, + agent, + collect_experience, + rank=None + ): + super(ActorNode, self).__init__(name, master, rank) + self.parameter_server = parameter_server + self.experience_server = experience_server + + mp.Process( + target=self.run_paramater_server, + args=( + name, + master.rref, + master.address, + master.port, + master.world_size, + self.rank, + parameter_server.rref, + experience_server.rref, + learner.rref, + agent, + collect_experience, + ), + ) + + @staticmethod + def train( + name, + master_rref, + master_address, + master_port, + world_size, + rank, + parameter_server_rref, + experience_server_rref, + learner_rref, + agent, + collect_experience, + + ): + print("Starting Actor") + set_environ(master_address, master_port) + rpc.init_rpc(name=name, world_size=world_size, rank=rank) + remote_method(Master.store_rref, master_rref, rpc.RRef(agent), name) + while not remote_method(DistributedTrainer.is_done(), learner_rref): + agent.load_weights( + remote_method(WeightHolder.get_weights(), parameter_server_rref) + ) + print("Done loadiing weights") + collect_experience(agent, experience_server_rref) + print("Done collecting experience") + + rpc.shutdown() + print("Shutdown actor") diff --git a/genrl/distributed/core.py b/genrl/distributed/core.py new file mode 100644 index 00000000..c89952a9 --- /dev/null +++ b/genrl/distributed/core.py @@ -0,0 +1,92 @@ +import torch.distributed.rpc as rpc + +from genrl.distributed.utils import remote_method, set_environ + + +class Node: + def __init__(self, name, master, rank): + self._name = name + self.master = master + self.master.increment_node_count() + if rank is None: + self._rank = master.node_count + elif rank > 0 and rank < master.world_size: + self._rank = rank + elif rank == 0: + raise ValueError("Rank of 0 is invalid for node") + elif rank >= master.world_size: + raise ValueError("Specified rank greater than allowed by world size") + else: + raise ValueError("Invalid value of rank") + + @property + def rref(self): + return self.master[self.name] + + @property + def rank(self): + return self._rank + + +class Master(Node): + def __init__(self, world_size, 
address="localhost", port=29500): + print("initing master") + set_environ(address, port) + print("initing master") + rpc.init_rpc(name="master", rank=0, world_size=world_size) + print("initing master") + self._world_size = world_size + self._rref = rpc.RRef(self) + print("initing master") + self._rref_reg = {} + self._address = address + self._port = port + self._node_counter = 0 + + def store_rref(self, parent_rref, idx): + self._rref_reg[idx] = parent_rref + + def fetch_rref(self, idx): + return self._rref_reg[idx] + + @property + def rref(self): + return self._rref + + @property + def world_size(self): + return self._world_size + + @property + def address(self): + return self._address + + @property + def port(self): + return self._port + + @property + def node_count(self): + return self._node_counter + + def increment_node_counter(self): + self._node_counter += 1 + if self.node_count >= self.world_size: + raise Exception("Attempt made to add more nodes than specified by world size") + + +class DistributedTrainer: + def __init__(self, agent): + self.agent = agent + self._completed_training_flag = False + + def train(self, parameter_server_rref, experience_server_rref): + raise NotImplementedError + + def train_wrapper(self, parameter_server_rref, experience_server_rref): + self._completed_training_flag = False + self.train(parameter_server_rref, experience_server_rref) + self._completed_training_flag = True + + def is_done(self): + return self._completed_training_flag diff --git a/genrl/distributed/experience_server.py b/genrl/distributed/experience_server.py new file mode 100644 index 00000000..5a524e92 --- /dev/null +++ b/genrl/distributed/experience_server.py @@ -0,0 +1,25 @@ +from genrl.distributed import Master, Node +from genrl.distributed.utils import remote_method, set_environ + +import torch.multiprocessing as mp +import torch.distributed.rpc as rpc +from genrl.core import ReplayBuffer + + +class ExperienceServer(Node): + def __init__(self, name, master, size, rank=None): + super(ExperienceServer, self).__init__(name, master, rank) + mp.Process( + target=self.run_paramater_server, + args=(name, master.rref, master.address, master.port, master.world_size, rank, size), + ) + + @staticmethod + def run_paramater_server(name, master_rref, master_address, master_port, world_size, rank, size): + print("Starting Parameter Server") + set_environ(master_address, master_port) + rpc.init_rpc(name=name, world_size=world_size, rank=rank) + buffer = ReplayBuffer(size) + remote_method(Master.store_rref, master_rref, rpc.RRef(buffer), name) + rpc.shutdown() + print("Shutdown experience server") diff --git a/genrl/distributed/learner.py b/genrl/distributed/learner.py new file mode 100644 index 00000000..af2f5c0c --- /dev/null +++ b/genrl/distributed/learner.py @@ -0,0 +1,48 @@ +from genrl.distributed import Master, Node +from genrl.distributed.utils import remote_method, set_environ + +import torch.multiprocessing as mp +import torch.distributed.rpc as rpc + + +class LearnerNode(Node): + def __init__(self, name, master, parameter_server, experience_server, trainer, rank=None): + super(LearnerNode, self).__init__(name, master, rank) + self.parameter_server = parameter_server + self.experience_server = experience_server + + mp.Process( + target=self.run_paramater_server, + args=( + name, + master.rref, + master.address, + master.port, + master.world_size, + self.rank, + parameter_server.rref, + experience_server.rref, + trainer, + ), + ) + + @staticmethod + def train( + name, + master_rref, + 
master_address, + master_port, + world_size, + rank, + parameter_server_rref, + experience_server_rref, + agent, + trainer, + ): + print("Starting Learner") + set_environ(master_address, master_port) + rpc.init_rpc(name=name, world_size=world_size, rank=rank) + remote_method(Master.store_rref, master_rref, rpc.RRef(trainer), name) + trainer.train_wrapper(parameter_server_rref, experience_server_rref) + rpc.shutdown() + print("Shutdown learner") diff --git a/genrl/distributed/parameter_server.py b/genrl/distributed/parameter_server.py index e69de29b..a6c4b35f 100644 --- a/genrl/distributed/parameter_server.py +++ b/genrl/distributed/parameter_server.py @@ -0,0 +1,37 @@ +from genrl.distributed import Master, Node +from genrl.distributed.utils import remote_method, set_environ + +import torch.multiprocessing as mp +import torch.distributed.rpc as rpc + + +class ParameterServer(Node): + def __init__(self, name, master, init_params, rank=None): + super(ParameterServer, self).__init__(name, master, rank) + mp.Process( + target=self.run_paramater_server, + args=(name, master.rref, master.address, master.port, master.world_size, self.rank, init_params), + ) + + @staticmethod + def run_paramater_server( + name, master_rref, master_address, master_port, world_size, rank, init_params + ): + print("Starting Parameter Server") + set_environ(master_address, master_port) + rpc.init_rpc(name=name, world_size=world_size, rank=rank) + params = init_params + remote_method(Master.store_rref, master_rref, rpc.RRef(params), name) + rpc.shutdown() + print("Shutdown parameter server") + + +class WeightHolder: + def __init__(self, init_weights): + self._weights = init_weights + + def store_weights(self, weights): + self._weights = weights + + def get_weights(self): + return self._weights diff --git a/genrl/distributed/transition_server.py b/genrl/distributed/transition_server.py deleted file mode 100644 index a8f6e256..00000000 --- a/genrl/distributed/transition_server.py +++ /dev/null @@ -1,2 +0,0 @@ -from genrl.core.buffers import ReplayBuffer - diff --git a/genrl/distributed/utils.py b/genrl/distributed/utils.py new file mode 100644 index 00000000..d2a9c3ff --- /dev/null +++ b/genrl/distributed/utils.py @@ -0,0 +1,33 @@ +import torch.distributed.rpc as rpc +import os + + +# --------- Helper Methods -------------------- + +# On the local node, call a method with first arg as the value held by the +# RRef. Other args are passed in as arguments to the function called. +# Useful for calling instance methods. method could be any matching function, including +# class methods. + + +def call_method(method, rref, *args, **kwargs): + return method(rref.local_value(), *args, **kwargs) + + +# Given an RRef, return the result of calling the passed in method on the value +# held by the RRef. This call is done on the remote node that owns +# the RRef and passes along the given argument. +# Example: If the value held by the RRef is of type Foo, then +# remote_method(Foo.bar, rref, arg1, arg2) is equivalent to calling +# .bar(arg1, arg2) on the remote node and getting the result +# back. 
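
(Editor's note) A minimal usage sketch of these helpers, not part of the patch: assuming `buffer_rref` and `param_server_rref` are RRefs owned by the experience-server and parameter-server processes, and that `agent` and `batch_size` are defined locally, passing the unbound method to `remote_method` runs it on the node that owns the RRef and returns the result to the caller over RPC.

    from genrl.core import ReplayBuffer
    from genrl.distributed import WeightHolder

    # runs buffer.sample(batch_size) on the process that owns buffer_rref
    # and returns the sampled batch to this process
    batch = remote_method(ReplayBuffer.sample, buffer_rref, batch_size)

    # stores the latest weights on the parameter-server process
    remote_method(WeightHolder.store_weights, param_server_rref, agent.get_weights())
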
+ + +def remote_method(method, rref, *args, **kwargs): + args = [method, rref] + list(args) + return rpc.rpc_sync(rref.owner(), call_method, args=args, kwargs=kwargs) + + +def set_environ(address, port): + os.environ["MASTER_ADDR"] = str(address) + os.environ["MASTER_PORT"] = str(port) From 73586d53cc5816b5d8fa127663fbb56a47ad1a4d Mon Sep 17 00:00:00 2001 From: Atharv Sonwane Date: Wed, 7 Oct 2020 13:34:20 +0000 Subject: [PATCH 06/27] working structure --- examples/distributed.py | 46 +++++---- genrl/distributed/__init__.py | 2 +- genrl/distributed/actor.py | 72 ++++++------- genrl/distributed/core.py | 135 +++++++++++++++++++------ genrl/distributed/experience_server.py | 25 +++-- genrl/distributed/learner.py | 54 +++++----- genrl/distributed/parameter_server.py | 23 ++--- genrl/distributed/utils.py | 33 ------ 8 files changed, 213 insertions(+), 177 deletions(-) delete mode 100644 genrl/distributed/utils.py diff --git a/examples/distributed.py b/examples/distributed.py index 76002cde..0ef3fa41 100644 --- a/examples/distributed.py +++ b/examples/distributed.py @@ -4,14 +4,18 @@ ParameterServer, ActorNode, LearnerNode, - remote_method, DistributedTrainer, WeightHolder, ) from genrl.core import ReplayBuffer from genrl.agents import DDPG import gym +import argparse +import torch.multiprocessing as mp +parser = argparse.ArgumentParser() +parser.add_argument("-n", type=int) +args = parser.parse_args() N_ACTORS = 2 BUFFER_SIZE = 10 @@ -25,10 +29,11 @@ def collect_experience(agent, experience_server_rref): done = False for i in range(MAX_ENV_STEPS): action = agent.select_action(obs) - next_obs, reward, done, info = agent.env.step() + next_obs, reward, done, info = agent.env.step(action) print("Sending experience") - remote_method(ReplayBuffer.push, experience_server_rref) + experience_server_rref.rpc_sync().push((obs, action, reward, done, next_obs)) print("Done sending experience") + obs = next_obs if done: break @@ -40,30 +45,35 @@ def __init__(self, agent, train_steps, batch_size): self.batch_size = batch_size def train(self, parameter_server_rref, experience_server_rref): + print("IN TRAIN") for i in range(self.train_steps): - batch = remote_method( - ReplayBuffer.sample, parameter_server_rref, self.batch_size - ) + print("GETTING BATCH") + batch = experience_server_rref.rpc_sync().sample(self.batch_size) + print("GOT BATCH") if batch is None: continue self.agent.update_params(batch) print("Storing weights") - remote_method( - WeightHolder.store_weights, - parameter_server_rref, - self.agent.get_weights(), - ) + parameter_server_rref.rpc_sync().store_weights(self.agent.get_weights()) + # remote_method(WeightHolder.store_weights, parameter_server_rref, self.agent.get_weights()) print("Done storing weights") + print(f"TRAINER: {i} STESPS done") -master = Master(world_size=6, address="localhost") -print("inited master") +mp.set_start_method("fork") + +master = Master(world_size=6, address="localhost", port=29504) env = gym.make("Pendulum-v0") -agent = DDPG(env) -parameter_server = ParameterServer("param-0", master, agent.get_weights, rank=1) -experience_server = ExperienceServer("experience-0", master, BUFFER_SIZE, rank=2) +agent = DDPG("mlp", env) +parameter_server = ParameterServer( + "param-0", master, WeightHolder(agent.get_weights()), rank=1 +) +buffer = ReplayBuffer(BUFFER_SIZE) +experience_server = ExperienceServer("experience-0", master, buffer, rank=2) trainer = MyTrainer(agent, TRAIN_STEPS, BATCH_SIZE) -learner = LearnerNode("learner-0", master, parameter_server, experience_server, 
trainer, rank=3) +learner = LearnerNode( + "learner-0", master, parameter_server, experience_server, trainer, rank=3 +) actors = [ ActorNode( f"actor-{i}", @@ -73,7 +83,7 @@ def train(self, parameter_server_rref, experience_server_rref): learner, agent, collect_experience, - rank=i+4 + rank=i + 4, ) for i in range(N_ACTORS) ] diff --git a/genrl/distributed/__init__.py b/genrl/distributed/__init__.py index 54751712..49692a5f 100644 --- a/genrl/distributed/__init__.py +++ b/genrl/distributed/__init__.py @@ -1,4 +1,4 @@ -from genrl.distributed.core import Master, DistributedTrainer, remote_method, Node +from genrl.distributed.core import Master, DistributedTrainer, Node from genrl.distributed.parameter_server import ParameterServer, WeightHolder from genrl.distributed.experience_server import ExperienceServer from genrl.distributed.actor import ActorNode diff --git a/genrl/distributed/actor.py b/genrl/distributed/actor.py index 2f1deaa6..50bb2125 100644 --- a/genrl/distributed/actor.py +++ b/genrl/distributed/actor.py @@ -1,12 +1,5 @@ -from genrl.distributed.core import ( - DistributedTrainer, - Master, - Node, -) -from genrl.distributed.parameter_server import WeightHolder -from genrl.distributed.utils import remote_method, set_environ - -import torch.multiprocessing as mp +from genrl.distributed.core import Node +from genrl.distributed.core import get_rref, store_rref import torch.distributed.rpc as rpc @@ -20,55 +13,52 @@ def __init__( learner, agent, collect_experience, - rank=None + rank=None, ): super(ActorNode, self).__init__(name, master, rank) self.parameter_server = parameter_server self.experience_server = experience_server - - mp.Process( - target=self.run_paramater_server, - args=( - name, - master.rref, - master.address, - master.port, - master.world_size, - self.rank, - parameter_server.rref, - experience_server.rref, - learner.rref, - agent, - collect_experience, + self.init_proc( + target=self.act, + kwargs=dict( + parameter_server_name=parameter_server.name, + experience_server_name=experience_server.name, + learner_name=learner.name, + agent=agent, + collect_experience=collect_experience, ), ) + self.start_proc() @staticmethod - def train( + def act( name, - master_rref, - master_address, - master_port, world_size, rank, - parameter_server_rref, - experience_server_rref, - learner_rref, + parameter_server_name, + experience_server_name, + learner_name, agent, collect_experience, - + **kwargs, ): - print("Starting Actor") - set_environ(master_address, master_port) rpc.init_rpc(name=name, world_size=world_size, rank=rank) - remote_method(Master.store_rref, master_rref, rpc.RRef(agent), name) - while not remote_method(DistributedTrainer.is_done(), learner_rref): - agent.load_weights( - remote_method(WeightHolder.get_weights(), parameter_server_rref) - ) + print("actor rpc inited") + rref = rpc.RRef(agent) + print(rref) + store_rref(name, rref) + print("stored rref") + parameter_server_rref = get_rref(parameter_server_name) + experience_server_rref = get_rref(experience_server_name) + learner_rref = get_rref(learner_name) + print( + f"{name}: {parameter_server_rref} {experience_server_rref} {learner_rref}" + ) + while not learner_rref.rpc_sync().is_done(): + print(f"{name}: going to load weights!") + agent.load_weights(parameter_server_rref.rpc_sync().get_weights()) print("Done loadiing weights") collect_experience(agent, experience_server_rref) print("Done collecting experience") rpc.shutdown() - print("Shutdown actor") diff --git a/genrl/distributed/core.py 
b/genrl/distributed/core.py index c89952a9..85d3fe2f 100644 --- a/genrl/distributed/core.py +++ b/genrl/distributed/core.py @@ -1,57 +1,131 @@ import torch.distributed.rpc as rpc -from genrl.distributed.utils import remote_method, set_environ +import threading + +from abc import ABC, abstractmethod +import torch.multiprocessing as mp +import os +import time + +_rref_reg = {} +_global_lock = threading.Lock() + + +def _get_rref(idx): + global _rref_reg + with _global_lock: + if idx in _rref_reg.keys(): + return _rref_reg[idx] + else: + return None + + +def _store_rref(idx, rref): + global _rref_reg + with _global_lock: + _rref_reg[idx] = rref + + +def get_rref(idx): + rref = rpc.rpc_sync("master", _get_rref, args=(idx,)) + while rref is None: + time.sleep(0.5) + rref = rpc.rpc_sync("master", _get_rref, args=(idx,)) + return rref + + +def store_rref(idx, rref): + rpc.rpc_sync("master", _store_rref, args=(idx, rref)) + + +def set_environ(address, port): + os.environ["MASTER_ADDR"] = str(address) + os.environ["MASTER_PORT"] = str(port) class Node: def __init__(self, name, master, rank): self._name = name self.master = master - self.master.increment_node_count() if rank is None: self._rank = master.node_count - elif rank > 0 and rank < master.world_size: + elif rank >= 0 and rank < master.world_size: self._rank = rank - elif rank == 0: - raise ValueError("Rank of 0 is invalid for node") elif rank >= master.world_size: raise ValueError("Specified rank greater than allowed by world size") else: raise ValueError("Invalid value of rank") + self.p = None + + def __del__(self): + if self.p is None: + raise RuntimeWarning( + "Removing node when process was not initialised properly" + ) + else: + self.p.join() + + @staticmethod + def _target_wrapper(target, **kwargs): + pid = os.getpid() + print(f"Starting {kwargs['name']} with pid {pid}") + set_environ(kwargs["master_address"], kwargs["master_port"]) + target(**kwargs) + print(f"Shutdown {kwargs['name']} with pid {pid}") + + def init_proc(self, target, kwargs): + kwargs.update( + dict( + name=self.name, + master_address=self.master.address, + master_port=self.master.port, + world_size=self.master.world_size, + rank=self.rank, + ) + ) + self.p = mp.Process(target=self._target_wrapper, args=(target,), kwargs=kwargs) + + def start_proc(self): + if self.p is None: + raise RuntimeError("Trying to start uninitialised process") + self.p.start() + + @property + def name(self): + return self._name @property def rref(self): - return self.master[self.name] + return get_rref(self.name) @property def rank(self): return self._rank -class Master(Node): - def __init__(self, world_size, address="localhost", port=29500): - print("initing master") +def _run_master(world_size): + print(f"Starting master at {os.getpid()}") + rpc.init_rpc("master", rank=0, world_size=world_size) + rpc.shutdown() + + +class Master: + def __init__(self, world_size, address="localhost", port=29501): set_environ(address, port) - print("initing master") - rpc.init_rpc(name="master", rank=0, world_size=world_size) - print("initing master") self._world_size = world_size - self._rref = rpc.RRef(self) - print("initing master") - self._rref_reg = {} self._address = address self._port = port self._node_counter = 0 - - def store_rref(self, parent_rref, idx): - self._rref_reg[idx] = parent_rref - - def fetch_rref(self, idx): - return self._rref_reg[idx] - - @property - def rref(self): - return self._rref + self.p = mp.Process(target=_run_master, args=(world_size,)) + self.p.start() + + def 
__del__(self): + if self.p is None: + raise RuntimeWarning( + "Shutting down master when it was not initialised properly" + ) + else: + self.p.join() @property def world_size(self): @@ -69,22 +143,19 @@ def port(self): def node_count(self): return self._node_counter - def increment_node_counter(self): - self._node_counter += 1 - if self.node_count >= self.world_size: - raise Exception("Attempt made to add more nodes than specified by world size") - -class DistributedTrainer: +class DistributedTrainer(ABC): def __init__(self, agent): self.agent = agent self._completed_training_flag = False + @abstractmethod def train(self, parameter_server_rref, experience_server_rref): - raise NotImplementedError + pass def train_wrapper(self, parameter_server_rref, experience_server_rref): self._completed_training_flag = False + print("TRAINER: CALLING TRAIN") self.train(parameter_server_rref, experience_server_rref) self._completed_training_flag = True diff --git a/genrl/distributed/experience_server.py b/genrl/distributed/experience_server.py index 5a524e92..08937298 100644 --- a/genrl/distributed/experience_server.py +++ b/genrl/distributed/experience_server.py @@ -1,25 +1,24 @@ -from genrl.distributed import Master, Node -from genrl.distributed.utils import remote_method, set_environ +from genrl.distributed import Node +from genrl.distributed.core import store_rref -import torch.multiprocessing as mp import torch.distributed.rpc as rpc -from genrl.core import ReplayBuffer class ExperienceServer(Node): - def __init__(self, name, master, size, rank=None): + def __init__(self, name, master, buffer, rank=None): super(ExperienceServer, self).__init__(name, master, rank) - mp.Process( + self.init_proc( target=self.run_paramater_server, - args=(name, master.rref, master.address, master.port, master.world_size, rank, size), + kwargs=dict(buffer=buffer), ) + self.start_proc() @staticmethod - def run_paramater_server(name, master_rref, master_address, master_port, world_size, rank, size): - print("Starting Parameter Server") - set_environ(master_address, master_port) + def run_paramater_server(name, world_size, rank, buffer, **kwargs): rpc.init_rpc(name=name, world_size=world_size, rank=rank) - buffer = ReplayBuffer(size) - remote_method(Master.store_rref, master_rref, rpc.RRef(buffer), name) + print("inited exp rpcs") + rref = rpc.RRef(buffer) + print(rref) + store_rref(name, rref) + print("serving buffer") rpc.shutdown() - print("Shutdown experience server") diff --git a/genrl/distributed/learner.py b/genrl/distributed/learner.py index af2f5c0c..383a4501 100644 --- a/genrl/distributed/learner.py +++ b/genrl/distributed/learner.py @@ -1,48 +1,48 @@ -from genrl.distributed import Master, Node -from genrl.distributed.utils import remote_method, set_environ +from genrl.distributed import Node +from genrl.distributed.core import get_rref, store_rref -import torch.multiprocessing as mp import torch.distributed.rpc as rpc class LearnerNode(Node): - def __init__(self, name, master, parameter_server, experience_server, trainer, rank=None): + def __init__( + self, name, master, parameter_server, experience_server, trainer, rank=None + ): super(LearnerNode, self).__init__(name, master, rank) self.parameter_server = parameter_server self.experience_server = experience_server - mp.Process( - target=self.run_paramater_server, - args=( - name, - master.rref, - master.address, - master.port, - master.world_size, - self.rank, - parameter_server.rref, - experience_server.rref, - trainer, + self.init_proc( + target=self.learn, 
+ kwargs=dict( + parameter_server_name=self.parameter_server.name, + experience_server_name=self.experience_server.name, + trainer=trainer, ), ) + self.start_proc() @staticmethod - def train( + def learn( name, - master_rref, - master_address, - master_port, world_size, rank, - parameter_server_rref, - experience_server_rref, - agent, + parameter_server_name, + experience_server_name, trainer, + **kwargs, ): - print("Starting Learner") - set_environ(master_address, master_port) rpc.init_rpc(name=name, world_size=world_size, rank=rank) - remote_method(Master.store_rref, master_rref, rpc.RRef(trainer), name) + print("inited trainer rpc") + rref = rpc.RRef(trainer) + print(rref) + store_rref(name, rref) + print("starting to train") + parameter_server_rref = get_rref(parameter_server_name) + experience_server_rref = get_rref( + experience_server_name, + ) + print(f"{name}: {parameter_server_rref} {experience_server_rref}") + print("TRAINER: CALLING WRAPPER") trainer.train_wrapper(parameter_server_rref, experience_server_rref) rpc.shutdown() - print("Shutdown learner") diff --git a/genrl/distributed/parameter_server.py b/genrl/distributed/parameter_server.py index a6c4b35f..e590faf2 100644 --- a/genrl/distributed/parameter_server.py +++ b/genrl/distributed/parameter_server.py @@ -1,29 +1,28 @@ -from genrl.distributed import Master, Node -from genrl.distributed.utils import remote_method, set_environ +from genrl.distributed import Node +from genrl.distributed.core import store_rref -import torch.multiprocessing as mp import torch.distributed.rpc as rpc class ParameterServer(Node): def __init__(self, name, master, init_params, rank=None): super(ParameterServer, self).__init__(name, master, rank) - mp.Process( + self.init_proc( target=self.run_paramater_server, - args=(name, master.rref, master.address, master.port, master.world_size, self.rank, init_params), + kwargs=dict(init_params=init_params), ) + self.start_proc() @staticmethod - def run_paramater_server( - name, master_rref, master_address, master_port, world_size, rank, init_params - ): - print("Starting Parameter Server") - set_environ(master_address, master_port) + def run_paramater_server(name, world_size, rank, init_params, **kwargs): rpc.init_rpc(name=name, world_size=world_size, rank=rank) + print("inited param server rpc") params = init_params - remote_method(Master.store_rref, master_rref, rpc.RRef(params), name) + rref = rpc.RRef(params) + print(rref) + store_rref(name, rref) + print("serving params") rpc.shutdown() - print("Shutdown parameter server") class WeightHolder: diff --git a/genrl/distributed/utils.py b/genrl/distributed/utils.py deleted file mode 100644 index d2a9c3ff..00000000 --- a/genrl/distributed/utils.py +++ /dev/null @@ -1,33 +0,0 @@ -import torch.distributed.rpc as rpc -import os - - -# --------- Helper Methods -------------------- - -# On the local node, call a method with first arg as the value held by the -# RRef. Other args are passed in as arguments to the function called. -# Useful for calling instance methods. method could be any matching function, including -# class methods. - - -def call_method(method, rref, *args, **kwargs): - return method(rref.local_value(), *args, **kwargs) - - -# Given an RRef, return the result of calling the passed in method on the value -# held by the RRef. This call is done on the remote node that owns -# the RRef and passes along the given argument. 
-# Example: If the value held by the RRef is of type Foo, then -# remote_method(Foo.bar, rref, arg1, arg2) is equivalent to calling -# .bar(arg1, arg2) on the remote node and getting the result -# back. - - -def remote_method(method, rref, *args, **kwargs): - args = [method, rref] + list(args) - return rpc.rpc_sync(rref.owner(), call_method, args=args, kwargs=kwargs) - - -def set_environ(address, port): - os.environ["MASTER_ADDR"] = str(address) - os.environ["MASTER_PORT"] = str(port) From 9ef6845b0f2d67110ea5c7b4422c27846a81f1d2 Mon Sep 17 00:00:00 2001 From: Atharv Sonwane Date: Wed, 7 Oct 2020 14:02:54 +0000 Subject: [PATCH 07/27] fixed integration bugs --- examples/distributed.py | 24 +- genrl/agents/deep/base/offpolicy.py | 13 +- genrl/agents/deep/ddpg/ddpg.py | 4 +- genrl/distributed/__init__.py | 2 +- genrl/distributed/actor.py | 11 +- genrl/distributed/core.py | 23 +- genrl/distributed/experience_server.py | 5 +- genrl/distributed/learner.py | 7 +- genrl/distributed/parameter_server.py | 5 +- genrl/trainers/__init__.py | 1 + genrl/trainers/distributed.py | 367 +++++++++++++------------ 11 files changed, 219 insertions(+), 243 deletions(-) diff --git a/examples/distributed.py b/examples/distributed.py index 0ef3fa41..03a84149 100644 --- a/examples/distributed.py +++ b/examples/distributed.py @@ -4,11 +4,11 @@ ParameterServer, ActorNode, LearnerNode, - DistributedTrainer, WeightHolder, ) from genrl.core import ReplayBuffer from genrl.agents import DDPG +from genrl.trainers import DistributedTrainer import gym import argparse import torch.multiprocessing as mp @@ -21,7 +21,7 @@ BUFFER_SIZE = 10 MAX_ENV_STEPS = 100 TRAIN_STEPS = 10 -BATCH_SIZE = 5 +BATCH_SIZE = 1 def collect_experience(agent, experience_server_rref): @@ -30,9 +30,7 @@ def collect_experience(agent, experience_server_rref): for i in range(MAX_ENV_STEPS): action = agent.select_action(obs) next_obs, reward, done, info = agent.env.step(action) - print("Sending experience") - experience_server_rref.rpc_sync().push((obs, action, reward, done, next_obs)) - print("Done sending experience") + experience_server_rref.rpc_sync().push((obs, action, reward, next_obs, done)) obs = next_obs if done: break @@ -45,24 +43,20 @@ def __init__(self, agent, train_steps, batch_size): self.batch_size = batch_size def train(self, parameter_server_rref, experience_server_rref): - print("IN TRAIN") - for i in range(self.train_steps): - print("GETTING BATCH") + i = 0 + while i < self.train_steps: batch = experience_server_rref.rpc_sync().sample(self.batch_size) - print("GOT BATCH") if batch is None: continue - self.agent.update_params(batch) - print("Storing weights") + self.agent.update_params(batch, 1) parameter_server_rref.rpc_sync().store_weights(self.agent.get_weights()) - # remote_method(WeightHolder.store_weights, parameter_server_rref, self.agent.get_weights()) - print("Done storing weights") - print(f"TRAINER: {i} STESPS done") + print(f"Trainer: {i + 1} / {self.train_steps} steps completed") + i += 1 mp.set_start_method("fork") -master = Master(world_size=6, address="localhost", port=29504) +master = Master(world_size=6, address="localhost", port=29500) env = gym.make("Pendulum-v0") agent = DDPG("mlp", env) parameter_server = ParameterServer( diff --git a/genrl/agents/deep/base/offpolicy.py b/genrl/agents/deep/base/offpolicy.py index 459ed58d..660c87c4 100644 --- a/genrl/agents/deep/base/offpolicy.py +++ b/genrl/agents/deep/base/offpolicy.py @@ -80,7 +80,7 @@ def _reshape_batch(self, batch: List): """ return [*batch] - def 
sample_from_buffer(self, beta: float = None): + def sample_from_buffer(self, beta: float = None, batch = None): """Samples experiences from the buffer and converts them into usable formats Args: @@ -89,11 +89,12 @@ def sample_from_buffer(self, beta: float = None): Returns: batch (:obj:`list`): Replay experiences sampled from the buffer """ - # Samples from the buffer - if beta is not None: - batch = self.replay_buffer.sample(self.batch_size, beta=beta) - else: - batch = self.replay_buffer.sample(self.batch_size) + if batch is None: + # Samples from the buffer + if beta is not None: + batch = self.replay_buffer.sample(self.batch_size, beta=beta) + else: + batch = self.replay_buffer.sample(self.batch_size) states, actions, rewards, next_states, dones = self._reshape_batch(batch) diff --git a/genrl/agents/deep/ddpg/ddpg.py b/genrl/agents/deep/ddpg/ddpg.py index a1122299..9ed54a02 100644 --- a/genrl/agents/deep/ddpg/ddpg.py +++ b/genrl/agents/deep/ddpg/ddpg.py @@ -79,14 +79,14 @@ def _create_model(self) -> None: self.optimizer_policy = opt.Adam(self.ac.actor.parameters(), lr=self.lr_policy) self.optimizer_value = opt.Adam(self.ac.critic.parameters(), lr=self.lr_value) - def update_params(self, update_interval: int) -> None: + def update_params(self, batch, update_interval: int) -> None: """Update parameters of the model Args: update_interval (int): Interval between successive updates of the target model """ for timestep in range(update_interval): - batch = self.sample_from_buffer() + batch = self.sample_from_buffer(batch=batch) value_loss = self.get_q_loss(batch) self.logs["value_loss"].append(value_loss.item()) diff --git a/genrl/distributed/__init__.py b/genrl/distributed/__init__.py index 49692a5f..c3276db1 100644 --- a/genrl/distributed/__init__.py +++ b/genrl/distributed/__init__.py @@ -1,4 +1,4 @@ -from genrl.distributed.core import Master, DistributedTrainer, Node +from genrl.distributed.core import Master, Node from genrl.distributed.parameter_server import ParameterServer, WeightHolder from genrl.distributed.experience_server import ExperienceServer from genrl.distributed.actor import ActorNode diff --git a/genrl/distributed/actor.py b/genrl/distributed/actor.py index 50bb2125..e64c3459 100644 --- a/genrl/distributed/actor.py +++ b/genrl/distributed/actor.py @@ -43,22 +43,15 @@ def act( **kwargs, ): rpc.init_rpc(name=name, world_size=world_size, rank=rank) - print("actor rpc inited") + print(f"{name}: RPC Initialised") rref = rpc.RRef(agent) - print(rref) store_rref(name, rref) - print("stored rref") parameter_server_rref = get_rref(parameter_server_name) experience_server_rref = get_rref(experience_server_name) learner_rref = get_rref(learner_name) - print( - f"{name}: {parameter_server_rref} {experience_server_rref} {learner_rref}" - ) + print(f"{name}: Begining experience collection") while not learner_rref.rpc_sync().is_done(): - print(f"{name}: going to load weights!") agent.load_weights(parameter_server_rref.rpc_sync().get_weights()) - print("Done loadiing weights") collect_experience(agent, experience_server_rref) - print("Done collecting experience") rpc.shutdown() diff --git a/genrl/distributed/core.py b/genrl/distributed/core.py index 85d3fe2f..a3fe7c66 100644 --- a/genrl/distributed/core.py +++ b/genrl/distributed/core.py @@ -23,6 +23,10 @@ def _get_rref(idx): def _store_rref(idx, rref): global _rref_reg with _global_lock: + if idx in _rref_reg.keys(): + raise Warning( + f"Re-assigning RRef for key: {idx}. 
Make sure you are not using duplicate names for nodes" + ) _rref_reg[idx] = rref @@ -142,22 +146,3 @@ def port(self): @property def node_count(self): return self._node_counter - - -class DistributedTrainer(ABC): - def __init__(self, agent): - self.agent = agent - self._completed_training_flag = False - - @abstractmethod - def train(self, parameter_server_rref, experience_server_rref): - pass - - def train_wrapper(self, parameter_server_rref, experience_server_rref): - self._completed_training_flag = False - print("TRAINER: CALLING TRAIN") - self.train(parameter_server_rref, experience_server_rref) - self._completed_training_flag = True - - def is_done(self): - return self._completed_training_flag diff --git a/genrl/distributed/experience_server.py b/genrl/distributed/experience_server.py index 08937298..95335569 100644 --- a/genrl/distributed/experience_server.py +++ b/genrl/distributed/experience_server.py @@ -16,9 +16,8 @@ def __init__(self, name, master, buffer, rank=None): @staticmethod def run_paramater_server(name, world_size, rank, buffer, **kwargs): rpc.init_rpc(name=name, world_size=world_size, rank=rank) - print("inited exp rpcs") + print(f"{name}: Initialised RPC") rref = rpc.RRef(buffer) - print(rref) store_rref(name, rref) - print("serving buffer") + print(f"{name}: Serving experience buffer") rpc.shutdown() diff --git a/genrl/distributed/learner.py b/genrl/distributed/learner.py index 383a4501..541e0125 100644 --- a/genrl/distributed/learner.py +++ b/genrl/distributed/learner.py @@ -33,16 +33,13 @@ def learn( **kwargs, ): rpc.init_rpc(name=name, world_size=world_size, rank=rank) - print("inited trainer rpc") + print(f"{name}: Initialised RPC") rref = rpc.RRef(trainer) - print(rref) store_rref(name, rref) - print("starting to train") parameter_server_rref = get_rref(parameter_server_name) experience_server_rref = get_rref( experience_server_name, ) - print(f"{name}: {parameter_server_rref} {experience_server_rref}") - print("TRAINER: CALLING WRAPPER") + print(f"{name}: Beginning training") trainer.train_wrapper(parameter_server_rref, experience_server_rref) rpc.shutdown() diff --git a/genrl/distributed/parameter_server.py b/genrl/distributed/parameter_server.py index e590faf2..ae5ec805 100644 --- a/genrl/distributed/parameter_server.py +++ b/genrl/distributed/parameter_server.py @@ -16,12 +16,11 @@ def __init__(self, name, master, init_params, rank=None): @staticmethod def run_paramater_server(name, world_size, rank, init_params, **kwargs): rpc.init_rpc(name=name, world_size=world_size, rank=rank) - print("inited param server rpc") + print(f"{name}: Initialised RPC") params = init_params rref = rpc.RRef(params) - print(rref) store_rref(name, rref) - print("serving params") + print(f"{name}: Serving parameters") rpc.shutdown() diff --git a/genrl/trainers/__init__.py b/genrl/trainers/__init__.py index 7410831b..c5448cc3 100644 --- a/genrl/trainers/__init__.py +++ b/genrl/trainers/__init__.py @@ -3,3 +3,4 @@ from genrl.trainers.classical import ClassicalTrainer # noqa from genrl.trainers.offpolicy import OffPolicyTrainer # noqa from genrl.trainers.onpolicy import OnPolicyTrainer # noqa +from genrl.trainers.distributed import DistributedTrainer # noqa diff --git a/genrl/trainers/distributed.py b/genrl/trainers/distributed.py index a768820b..6f14e2e9 100644 --- a/genrl/trainers/distributed.py +++ b/genrl/trainers/distributed.py @@ -1,183 +1,190 @@ -import copy -import multiprocessing as mp -import threading -from typing import Type, Union +from abc import ABC, abstractmethod 
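
(Editor's note) The offpolicy.py and ddpg.py hunks above change the contract this trainer rewrite relies on: `update_params` now takes the batch as an argument instead of sampling internally, and `ReplayBuffer.sample` (changed in an earlier patch) returns None while the buffer holds fewer than `batch_size` transitions. A minimal sketch of the resulting learner step, assuming `experience_server_rref` holds a `ReplayBuffer` and `parameter_server_rref` holds a `WeightHolder` as in examples/distributed.py:

    batch = experience_server_rref.rpc_sync().sample(batch_size)
    if batch is not None:
        # one gradient update on the remotely sampled batch
        agent.update_params(batch, update_interval=1)
        # publish the new weights for actors to pull
        parameter_server_rref.rpc_sync().store_weights(agent.get_weights())
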
-import gym -import numpy as np -import reverb -import tensorflow as tf -import torch -from genrl.trainers import Trainer - - -class DistributedOffPolicyTrainer: - """Distributed Off Policy Trainer Class - - Trainer class for Distributed Off Policy Agents - - """ - - def __init__( - self, - agent, - env, - buffer_server_port=None, - param_server_port=None, - **kwargs, - ): - self.env = env +class DistributedTrainer(ABC): + def __init__(self, agent): self.agent = agent - self.buffer_server_port = buffer_server_port - self.param_server_port = param_server_port - - def train( - self, n_actors, max_buffer_size, batch_size, max_updates, update_interval - ): - buffer_server = reverb.Server( - tables=[ - reverb.Table( - name="replay_buffer", - sampler=reverb.selectors.Uniform(), - remover=reverb.selectors.Fifo(), - max_size=max_buffer_size, - rate_limiter=reverb.rate_limiters.MinSize(1), - ) - ], - port=self.buffer_server_port, - ) - buffer_server_address = f"localhost:{buffer_server.port}" - - param_server = reverb.Server( - tables=[ - reverb.Table( - name="param_buffer", - sampler=reverb.selectors.Uniform(), - remover=reverb.selectors.Fifo(), - max_size=1, - rate_limiter=reverb.rate_limiters.MinSize(1), - ) - ], - port=self.param_server_port, - ) - param_server_address = f"localhost:{param_server.port}" - - actor_procs = [] - for _ in range(n_actors): - p = threading.Thread( - target=run_actor, - args=( - copy.deepcopy(self.agent), - copy.deepcopy(self.env), - buffer_server_address, - param_server_address, - ), - daemon=True, - ) - p.start() - actor_procs.append(p) - - learner_proc = threading.Thread( - target=run_learner, - args=( - copy.deepcopy(self.agent), - max_updates, - update_interval, - buffer_server_address, - param_server_address, - batch_size, - ), - daemon=True, - ) - learner_proc.daemon = True - learner_proc.start() - learner_proc.join() - - # param_client = reverb.Client(param_server_address) - # self.agent.replay_buffer = ReverbReplayDataset( - # self.agent.env, buffer_server_address, batch_size - # ) - - # for _ in range(max_updates): - # self.agent.update_params(update_interval) - # params = self.agent.get_weights() - # param_client.insert(params.values(), {"param_buffer": 1}) - # print("weights updated") - # # print(list(param_client.sample("param_buffer"))) - - -def run_actor(agent, env, buffer_server_address, param_server_address): - buffer_client = reverb.Client(buffer_server_address) - param_client = reverb.TFClient(param_server_address) - - state = env.reset().astype(np.float32) - - for i in range(10): - # params = param_client.sample("param_buffer", []) - # print("Sampling done") - # print(list(params)) - # agent.load_weights(params) - - action = agent.select_action(state).numpy() - next_state, reward, done, _ = env.step(action) - next_state = next_state.astype(np.float32) - reward = np.array([reward]).astype(np.float32) - done = np.array([done]).astype(np.bool) - - buffer_client.insert([state, action, reward, next_state, done], {"replay_buffer": 1}) - print("transition inserted") - state = env.reset().astype(np.float32) if done else next_state.copy() - - -def run_learner( - agent, - max_updates, - update_interval, - buffer_server_address, - param_server_address, - batch_size, -): - param_client = reverb.Client(param_server_address) - agent.replay_buffer = ReverbReplayDataset( - agent.env, buffer_server_address, batch_size - ) - for _ in range(max_updates): - agent.update_params(update_interval) - params = agent.get_weights() - param_client.insert(params.values(), 
{"param_buffer": 1}) - print("weights updated") - # print(list(param_client.sample("param_buffer"))) - - -class ReverbReplayDataset: - def __init__(self, env, address, batch_size): - action_dtype = ( - np.int64 - if isinstance(env.action_space, gym.spaces.discrete.Discrete) - else np.float32 - ) - obs_shape = env.observation_space.shape - action_shape = env.action_space.shape - reward_shape = 1 - done_shape = 1 - - self._dataset = reverb.ReplayDataset( - server_address=address, - table="replay_buffer", - max_in_flight_samples_per_worker=2 * batch_size, - dtypes=(np.float32, action_dtype, np.float32, np.float32, np.bool), - shapes=( - tf.TensorShape(obs_shape), - tf.TensorShape(action_shape), - tf.TensorShape(reward_shape), - tf.TensorShape(obs_shape), - tf.TensorShape(done_shape), - ), - ) - self._data_iter = self._dataset.batch(batch_size).as_numpy_iterator() - - def sample(self, *args, **kwargs): - sample = next(self._data_iter) - obs, a, r, next_obs, d = [torch.from_numpy(t).float() for t in sample.data] - return obs, a, r, next_obs, d + self._completed_training_flag = False + + @abstractmethod + def train(self, parameter_server_rref, experience_server_rref): + pass + + def train_wrapper(self, parameter_server_rref, experience_server_rref): + self._completed_training_flag = False + self.train(parameter_server_rref, experience_server_rref) + self._completed_training_flag = True + + def is_done(self): + return self._completed_training_flag + + +# class DistributedOffPolicyTrainer: +# """Distributed Off Policy Trainer Class + +# Trainer class for Distributed Off Policy Agents + +# """ + +# def __init__( +# self, +# agent, +# env, +# buffer_server_port=None, +# param_server_port=None, +# **kwargs, +# ): +# self.env = env +# self.agent = agent +# self.buffer_server_port = buffer_server_port +# self.param_server_port = param_server_port + +# def train( +# self, n_actors, max_buffer_size, batch_size, max_updates, update_interval +# ): +# buffer_server = reverb.Server( +# tables=[ +# reverb.Table( +# name="replay_buffer", +# sampler=reverb.selectors.Uniform(), +# remover=reverb.selectors.Fifo(), +# max_size=max_buffer_size, +# rate_limiter=reverb.rate_limiters.MinSize(1), +# ) +# ], +# port=self.buffer_server_port, +# ) +# buffer_server_address = f"localhost:{buffer_server.port}" + +# param_server = reverb.Server( +# tables=[ +# reverb.Table( +# name="param_buffer", +# sampler=reverb.selectors.Uniform(), +# remover=reverb.selectors.Fifo(), +# max_size=1, +# rate_limiter=reverb.rate_limiters.MinSize(1), +# ) +# ], +# port=self.param_server_port, +# ) +# param_server_address = f"localhost:{param_server.port}" + +# actor_procs = [] +# for _ in range(n_actors): +# p = threading.Thread( +# target=run_actor, +# args=( +# copy.deepcopy(self.agent), +# copy.deepcopy(self.env), +# buffer_server_address, +# param_server_address, +# ), +# daemon=True, +# ) +# p.start() +# actor_procs.append(p) + +# learner_proc = threading.Thread( +# target=run_learner, +# args=( +# copy.deepcopy(self.agent), +# max_updates, +# update_interval, +# buffer_server_address, +# param_server_address, +# batch_size, +# ), +# daemon=True, +# ) +# learner_proc.daemon = True +# learner_proc.start() +# learner_proc.join() + +# # param_client = reverb.Client(param_server_address) +# # self.agent.replay_buffer = ReverbReplayDataset( +# # self.agent.env, buffer_server_address, batch_size +# # ) + +# # for _ in range(max_updates): +# # self.agent.update_params(update_interval) +# # params = self.agent.get_weights() +# # 
param_client.insert(params.values(), {"param_buffer": 1}) +# # print("weights updated") +# # # print(list(param_client.sample("param_buffer"))) + + +# def run_actor(agent, env, buffer_server_address, param_server_address): +# buffer_client = reverb.Client(buffer_server_address) +# param_client = reverb.TFClient(param_server_address) + +# state = env.reset().astype(np.float32) + +# for i in range(10): +# # params = param_client.sample("param_buffer", []) +# # print("Sampling done") +# # print(list(params)) +# # agent.load_weights(params) + +# action = agent.select_action(state).numpy() +# next_state, reward, done, _ = env.step(action) +# next_state = next_state.astype(np.float32) +# reward = np.array([reward]).astype(np.float32) +# done = np.array([done]).astype(np.bool) + +# buffer_client.insert([state, action, reward, next_state, done], {"replay_buffer": 1}) +# print("transition inserted") +# state = env.reset().astype(np.float32) if done else next_state.copy() + + +# def run_learner( +# agent, +# max_updates, +# update_interval, +# buffer_server_address, +# param_server_address, +# batch_size, +# ): +# param_client = reverb.Client(param_server_address) +# agent.replay_buffer = ReverbReplayDataset( +# agent.env, buffer_server_address, batch_size +# ) +# for _ in range(max_updates): +# agent.update_params(update_interval) +# params = agent.get_weights() +# param_client.insert(params.values(), {"param_buffer": 1}) +# print("weights updated") +# # print(list(param_client.sample("param_buffer"))) + + +# class ReverbReplayDataset: +# def __init__(self, env, address, batch_size): +# action_dtype = ( +# np.int64 +# if isinstance(env.action_space, gym.spaces.discrete.Discrete) +# else np.float32 +# ) +# obs_shape = env.observation_space.shape +# action_shape = env.action_space.shape +# reward_shape = 1 +# done_shape = 1 + +# self._dataset = reverb.ReplayDataset( +# server_address=address, +# table="replay_buffer", +# max_in_flight_samples_per_worker=2 * batch_size, +# dtypes=(np.float32, action_dtype, np.float32, np.float32, np.bool), +# shapes=( +# tf.TensorShape(obs_shape), +# tf.TensorShape(action_shape), +# tf.TensorShape(reward_shape), +# tf.TensorShape(obs_shape), +# tf.TensorShape(done_shape), +# ), +# ) +# self._data_iter = self._dataset.batch(batch_size).as_numpy_iterator() + +# def sample(self, *args, **kwargs): +# sample = next(self._data_iter) +# obs, a, r, next_obs, d = [torch.from_numpy(t).float() for t in sample.data] +# return obs, a, r, next_obs, d From 072d545dc77df543b4ae2cd737c8fde73a23b0c0 Mon Sep 17 00:00:00 2001 From: Atharv Sonwane Date: Wed, 7 Oct 2020 14:03:29 +0000 Subject: [PATCH 08/27] removed unneccary files --- examples/distributed_old_1.py | 76 ---------- examples/distributed_old_2.py | 275 ---------------------------------- 2 files changed, 351 deletions(-) delete mode 100644 examples/distributed_old_1.py delete mode 100644 examples/distributed_old_2.py diff --git a/examples/distributed_old_1.py b/examples/distributed_old_1.py deleted file mode 100644 index ead8c8c6..00000000 --- a/examples/distributed_old_1.py +++ /dev/null @@ -1,76 +0,0 @@ -from genrl.agents import DDPG -from genrl.trainers import OffPolicyTrainer -from genrl.trainers.distributed import DistributedOffPolicyTrainer -from genrl.environments import VectorEnv -import gym -import reverb -import numpy as np -import multiprocessing as mp -import threading - - -# env = VectorEnv("Pendulum-v0") -# agent = DDPG("mlp", env) -# trainer = OffPolicyTrainer(agent, env) -# trainer.train() - - -env = 
gym.make("Pendulum-v0") -agent = DDPG("mlp", env) - -# o = env.reset() -# action = agent.select_action(o) -# next_state, reward, done, info = env.step(action.numpy()) - -# buffer_server = reverb.Server( -# tables=[ -# reverb.Table( -# name="replay_buffer", -# sampler=reverb.selectors.Uniform(), -# remover=reverb.selectors.Fifo(), -# max_size=10, -# rate_limiter=reverb.rate_limiters.MinSize(4), -# ) -# ], -# port=None, -# ) -# client = reverb.Client(f"localhost:{buffer_server.port}") -# print(client.server_info()) - -# state = env.reset() -# action = agent.select_action(state) -# next_state, reward, done, info = env.step(action.numpy()) - -# state = next_state.copy() -# print(client.server_info()) -# print("going to insert") -# client.insert([state, action, np.array([reward]), np.array([done]), next_state], {"replay_buffer": 1}) -# client.insert([state, action, np.array([reward]), np.array([done]), next_state], {"replay_buffer": 1}) -# client.insert([state, action, np.array([reward]), np.array([done]), next_state], {"replay_buffer": 1}) -# client.insert([state, action, np.array([reward]), np.array([done]), next_state], {"replay_buffer": 1}) -# client.insert([state, action, np.array([reward]), np.array([done]), next_state], {"replay_buffer": 1}) -# print("inserted") - -# # print(list(client.sample('replay_buffer', num_samples=1))) - - -# def sample(address): -# print("-- entered proc") -# client = reverb.Client(address) -# print("-- started client") -# print(list(client.sample('replay_buffer', num_samples=1))) - -# a = f"localhost:{buffer_server.port}" -# print("create process") -# # p = mp.Process(target=sample, args=(a,)) -# p = threading.Thread(target=sample, args=(a,)) -# print("start process") -# p.start() -# print("wait process") -# p.join() -# print("end process") - -trainer = DistributedOffPolicyTrainer(agent, env) -trainer.train( - n_actors=2, max_buffer_size=100, batch_size=4, max_updates=10, update_interval=1 -) diff --git a/examples/distributed_old_2.py b/examples/distributed_old_2.py deleted file mode 100644 index 5abf7a95..00000000 --- a/examples/distributed_old_2.py +++ /dev/null @@ -1,275 +0,0 @@ -from genrl.core.buffers import ReplayBuffer -import os - -from genrl.agents import DDPG -import torch -import torch.distributed.rpc as rpc -import torch.multiprocessing as mp -import torch.nn as nn -import torch.nn.functional as F -from torch import optim -import argparse - -import copy - -import gym -import numpy as np - -os.environ["MASTER_ADDR"] = "localhost" -os.environ["MASTER_PORT"] = "29500" - -# to call a function on an rref, we could do the following -# _remote_method(some_func, rref, *args) - -def _call_method(method, rref, *args, **kwargs): - return method(rref.local_value(), *args, **kwargs) - - -def _remote_method(method, rref, *args, **kwargs): - args = [method, rref] + list(args) - return rpc.rpc_sync(rref.owner(), _call_method, args=args, kwargs=kwargs) - - -gloabl_lock = mp.Lock() - - -class ParamServer: - def __init__(self, init_params): - self.params = init_params - # self.lock = mp.Lock() - - def store_params(self, new_params): - # with self.lock: - with gloabl_lock: - self.params = new_params - - def get_params(self): - # with self.lock: - with gloabl_lock: - return self.params - - -class DistributedReplayBuffer: - def __init__(self, size): - self.size = size - self.len = 0 - self._buffer = ReplayBuffer(self.size) - - -class DistributedOffPolicyTrainer: - """Distributed Off Policy Trainer Class - - Trainer class for Distributed Off Policy Agents - - """ - def 
__init__( - self, - agent, - env, - **kwargs, - ): - self.env = env - self.agent = agent - - def train( - self, n_actors, max_buffer_size, batch_size, max_updates, update_interval - ): - - print("a") - world_size = n_actors + 2 - completed = mp.Value("i", 0) - print("a") - param_server_rref_q = mp.Queue(1) - param_server_p = mp.Process( - target=run_param_server, args=(param_server_rref_q, world_size,) - ) - param_server_p.start() - param_server_rref = param_server_rref_q.get() - param_server_rref_q.close() - - print("a") - buffer_rref_q = mp.Queue(1) - buffer_p = mp.Process(target=run_buffer, args=(max_buffer_size, buffer_rref_q, world_size,)) - buffer_p.start() - buffer_rref = buffer_rref_q.get() - buffer_rref_q.close() - print("a") - - actor_ps = [] - for i in range(n_actors): - a_p = mp.Process( - target=run_actor, - args=( - i, - copy.deepcopy(self.agent), - copy.deepcopy(self.env), - param_server_rref,~ - buffer_rref, - world_size, - completed - ), - ) - a_p.start() - actor_ps.append(a_p) - - learner_p = mp.Process( - target=run_learner, - args=(max_updates, batch_size, self.agent, param_server_rref, buffer_rref, world_size, completed), - ) - learner_p.start() - - learner_p.join() - for a in actor_ps: - a.join() - buffer_p.join() - param_server_p.join() - - -def run_param_server(q, world_size): - print("Running parameter server") - rpc.init_rpc(name="param_server", rank=0, world_size=world_size) - print("d") - param_server = ParamServer(None) - param_server_rref = rpc.RRef(param_server) - q.put(param_server_rref) - rpc.shutdown() - print("param server shutting down") - - -def run_buffer(max_buffer_size, q, world_size): - print("Running buffer server") - rpc.init_rpc(name="buffer", rank=1, world_size=world_size) - buffer = ReplayBuffer(max_buffer_size) - buffer_rref = rpc.RRef(buffer) - q.put(buffer_rref) - rpc.shutdown() - print("buffer shutting down") - - -def run_learner(max_updates, batch_size, agent, param_server_rref, buffer_rref, world_size, completed): - print("Running learner") - rpc.init_rpc(name="learner", rank=world_size - 1, world_size=world_size) - i = 0 - while i < max_updates: - batch = _remote_method(ReplayBuffer.sample, buffer_rref, batch_size) - if batch is None: - continue - agent.update_params(batch) - _remote_method(ParamServer.store_params, param_server_rref, agent.get_weights()) - print("weights updated") - i += 1 - print(i) - completed.value = 1 - rpc.shutdown() - print("learner shutting down") - - -def run_actor(i, agent, env, param_server_rref, buffer_rref, world_size, completed): - print(f"Running actor {i}") - - rpc.init_rpc(name=f"action_{i}", rank=i + 1, world_size=world_size) - - state = env.reset().astype(np.float32) - - while not completed.value == 1: - params = _remote_method(ParamServer.get_params, param_server_rref) - agent.load_weights(params) - - action = agent.select_action(state).numpy() - next_state, reward, done, _ = env.step(action) - next_state = next_state.astype(np.float32) - reward = np.array([reward]).astype(np.float32) - done = np.array([done]).astype(np.bool) - - print("attempting to insert transition") - _remote_method(ReplayBuffer.push, buffer_rref, [state, action, reward, next_state, done]) - print("inserted transition") - state = env.reset().astype(np.float32) if done else next_state.copy() - - rpc.shutdown() - print("actor shutting down") - -env = gym.make("Pendulum-v0") -agent = DDPG("mlp", env) - -trainer = DistributedOffPolicyTrainer(agent, env) -trainer.train( - n_actors=1, max_buffer_size=100, batch_size=1, max_updates=100, 
update_interval=1 -) - - -# if __name__ == '__main__': -# parser = argparse.ArgumentParser( -# description="Parameter-Server RPC based training") -# parser.add_argument( -# "--world_size", -# type=int, -# default=4, -# help="""Total number of participating processes. Should be the number -# of actors + 3.""") -# parser.add_argument( -# "--run", -# type=str, -# default="param_server", -# choices=["param_server", "buffer", "learner", "actor"], -# help="Which program to run") -# parser.add_argument( -# "--master_addr", -# type=str, -# default="localhost", -# help="""Address of master, will default to localhost if not provided. -# Master must be able to accept network traffic on the address + port.""") -# parser.add_argument( -# "--master_port", -# type=str, -# default="29500", -# help="""Port that master is listening on, will default to 29500 if not -# provided. Master must be able to accept network traffic on the host and port.""") - -# args = parser.parse_args() - -# os.environ['MASTER_ADDR'] = args.master_addr -# os.environ["MASTER_PORT"] = args.master_port - -# processes = [] -# world_size = args.world_size -# if args.run == "param_server": -# p = mp.Process(target=run_param_server, args=(world_size)) -# p.start() -# processes.append(p) -# elif args.run == "buffer": -# p = mp.Process(target=run_buffer, args=(world_size)) -# p.start() -# processes.append(p) -# # Get data to train on -# train_loader = torch.utils.data.DataLoader( -# datasets.MNIST('../data', train=True, download=True, -# transform=transforms.Compose([ -# transforms.ToTensor(), -# transforms.Normalize((0.1307,), (0.3081,)) -# ])), -# batch_size=32, shuffle=True,) -# test_loader = torch.utils.data.DataLoader( -# datasets.MNIST( -# '../data', -# train=False, -# transform=transforms.Compose([ -# transforms.ToTensor(), -# transforms.Normalize((0.1307,), (0.3081,)) -# ])), -# batch_size=32, -# shuffle=True, -# ) -# # start training worker on this node -# p = mp.Process( -# target=run_worker, -# args=( -# args.rank, -# world_size, args.num_gpus, -# train_loader, -# test_loader)) -# p.start() -# processes.append(p) - -# for p in processes: -# p.join() From 64db1c1ad20d94cae3df05243ab480e36fceefb9 Mon Sep 17 00:00:00 2001 From: Atharv Sonwane Date: Thu, 8 Oct 2020 04:18:47 +0000 Subject: [PATCH 09/27] added support for running from multiple scripts --- examples/distributed.py | 24 ++++++++----------- genrl/distributed/actor.py | 14 +++++------ genrl/distributed/core.py | 45 +++++++++++++++++++----------------- genrl/distributed/learner.py | 9 +++----- 4 files changed, 43 insertions(+), 49 deletions(-) diff --git a/examples/distributed.py b/examples/distributed.py index 03a84149..06c9c0d4 100644 --- a/examples/distributed.py +++ b/examples/distributed.py @@ -10,17 +10,13 @@ from genrl.agents import DDPG from genrl.trainers import DistributedTrainer import gym -import argparse import torch.multiprocessing as mp -parser = argparse.ArgumentParser() -parser.add_argument("-n", type=int) -args = parser.parse_args() N_ACTORS = 2 BUFFER_SIZE = 10 MAX_ENV_STEPS = 100 -TRAIN_STEPS = 10 +TRAIN_STEPS = 50 BATCH_SIZE = 1 @@ -56,7 +52,7 @@ def train(self, parameter_server_rref, experience_server_rref): mp.set_start_method("fork") -master = Master(world_size=6, address="localhost", port=29500) +master = Master(world_size=8, address="localhost", port=29500) env = gym.make("Pendulum-v0") agent = DDPG("mlp", env) parameter_server = ParameterServer( @@ -66,17 +62,17 @@ def train(self, parameter_server_rref, experience_server_rref): experience_server = 
ExperienceServer("experience-0", master, buffer, rank=2) trainer = MyTrainer(agent, TRAIN_STEPS, BATCH_SIZE) learner = LearnerNode( - "learner-0", master, parameter_server, experience_server, trainer, rank=3 + "learner-0", master, "param-0", "experience-0", trainer, rank=3 ) actors = [ ActorNode( - f"actor-{i}", - master, - parameter_server, - experience_server, - learner, - agent, - collect_experience, + name=f"actor-{i}", + master=master, + parameter_server_name="param-0", + experience_server_name="experience-0", + learner_name="learner-0", + agent=agent, + collect_experience=collect_experience, rank=i + 4, ) for i in range(N_ACTORS) diff --git a/genrl/distributed/actor.py b/genrl/distributed/actor.py index e64c3459..32dd0f0e 100644 --- a/genrl/distributed/actor.py +++ b/genrl/distributed/actor.py @@ -8,22 +8,20 @@ def __init__( self, name, master, - parameter_server, - experience_server, - learner, + parameter_server_name, + experience_server_name, + learner_name, agent, collect_experience, rank=None, ): super(ActorNode, self).__init__(name, master, rank) - self.parameter_server = parameter_server - self.experience_server = experience_server self.init_proc( target=self.act, kwargs=dict( - parameter_server_name=parameter_server.name, - experience_server_name=experience_server.name, - learner_name=learner.name, + parameter_server_name=parameter_server_name, + experience_server_name=experience_server_name, + learner_name=learner_name, agent=agent, collect_experience=collect_experience, ), diff --git a/genrl/distributed/core.py b/genrl/distributed/core.py index a3fe7c66..f4f2f750 100644 --- a/genrl/distributed/core.py +++ b/genrl/distributed/core.py @@ -2,7 +2,6 @@ import threading -from abc import ABC, abstractmethod import torch.multiprocessing as mp import os import time @@ -30,6 +29,12 @@ def _store_rref(idx, rref): _rref_reg[idx] = rref +def _get_num_rrefs(): + global _rref_reg + with _global_lock: + return len(_rref_reg.keys()) + + def get_rref(idx): rref = rpc.rpc_sync("master", _get_rref, args=(idx,)) while rref is None: @@ -51,9 +56,7 @@ class Node: def __init__(self, name, master, rank): self._name = name self.master = master - if rank is None: - self._rank = master.node_count - elif rank >= 0 and rank < master.world_size: + if rank >= 0 and rank < master.world_size: self._rank = rank elif rank >= master.world_size: raise ValueError("Specified rank greater than allowed by world size") @@ -107,30 +110,30 @@ def rank(self): return self._rank -def _run_master(world_size): - print(f"Starting master at {os.getpid()}") - rpc.init_rpc("master", rank=0, world_size=world_size) - rpc.shutdown() - - class Master: - def __init__(self, world_size, address="localhost", port=29501): + def __init__(self, world_size, address="localhost", port=29501, secondary=False): set_environ(address, port) self._world_size = world_size self._address = address self._port = port - self._node_counter = 0 - self.p = mp.Process(target=_run_master, args=(world_size,)) - self.p.start() + self._secondary = secondary - def __del__(self): - if self.p is None: - raise RuntimeWarning( - "Shutting down master when it was not initialised properly" - ) + if not self._secondary: + self.p = mp.Process(target=self._run_master, args=(world_size,)) + self.p.start() else: + self.p = None + + def __del__(self): + if not self.p is None: self.p.join() + @staticmethod + def _run_master(world_size): + print(f"Starting master at {os.getpid()}") + rpc.init_rpc("master", rank=0, world_size=world_size) + rpc.shutdown() + @property def 
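Because nodes are now wired together purely by registered names, additional actors can be launched from a second script that attaches to the already running master. A minimal sketch of such a companion script, mirroring the offpolicy_distributed_secondary.py example added later in this series (world_size, server names and ranks are assumed to match whatever the primary script uses):

from genrl.distributed import Master, ActorNode
from genrl.agents import DDPG
import gym
import torch.multiprocessing as mp

MAX_ENV_STEPS = 100


def collect_experience(agent, experience_server_rref):
    # same experience collection loop as the primary script uses
    obs = agent.env.reset()
    for _ in range(MAX_ENV_STEPS):
        action = agent.select_action(obs)
        next_obs, reward, done, _ = agent.env.step(action)
        experience_server_rref.rpc_sync().push((obs, action, reward, next_obs, done))
        obs = next_obs
        if done:
            break


mp.set_start_method("fork")

# secondary=True attaches to the master started by the primary script
# instead of spawning a new master process of its own
master = Master(world_size=8, address="localhost", port=29500, secondary=True)
agent = DDPG("mlp", gym.make("Pendulum-v0"))
extra_actors = [
    ActorNode(
        name=f"actor-{i + 2}",
        master=master,
        parameter_server_name="param-0",
        experience_server_name="experience-0",
        learner_name="learner-0",
        agent=agent,
        collect_experience=collect_experience,
        rank=i + 6,
    )
    for i in range(2)
]

The primary script is started first, so that the master and the named servers exist before these extra actors try to resolve them.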
world_size(self): return self._world_size @@ -144,5 +147,5 @@ def port(self): return self._port @property - def node_count(self): - return self._node_counter + def is_secondary(self): + return self._secondary diff --git a/genrl/distributed/learner.py b/genrl/distributed/learner.py index 541e0125..ebce8407 100644 --- a/genrl/distributed/learner.py +++ b/genrl/distributed/learner.py @@ -6,17 +6,14 @@ class LearnerNode(Node): def __init__( - self, name, master, parameter_server, experience_server, trainer, rank=None + self, name, master, parameter_server_name, experience_server_name, trainer, rank=None ): super(LearnerNode, self).__init__(name, master, rank) - self.parameter_server = parameter_server - self.experience_server = experience_server - self.init_proc( target=self.learn, kwargs=dict( - parameter_server_name=self.parameter_server.name, - experience_server_name=self.experience_server.name, + parameter_server_name=parameter_server_name, + experience_server_name=experience_server_name, trainer=trainer, ), ) From 4d57a06495f79ab0b54d4fdf0c9154405a8a90c4 Mon Sep 17 00:00:00 2001 From: Atharv Sonwane Date: Thu, 8 Oct 2020 05:12:12 +0000 Subject: [PATCH 10/27] added evaluate to trainer --- examples/distributed.py | 10 +- genrl/distributed/core.py | 2 +- genrl/distributed/learner.py | 8 +- genrl/trainers/distributed.py | 204 +++++----------------------------- 4 files changed, 42 insertions(+), 182 deletions(-) diff --git a/examples/distributed.py b/examples/distributed.py index 06c9c0d4..093d5106 100644 --- a/examples/distributed.py +++ b/examples/distributed.py @@ -12,7 +12,6 @@ import gym import torch.multiprocessing as mp - N_ACTORS = 2 BUFFER_SIZE = 10 MAX_ENV_STEPS = 100 @@ -25,7 +24,7 @@ def collect_experience(agent, experience_server_rref): done = False for i in range(MAX_ENV_STEPS): action = agent.select_action(obs) - next_obs, reward, done, info = agent.env.step(action) + next_obs, reward, done, _ = agent.env.step(action) experience_server_rref.rpc_sync().push((obs, action, reward, next_obs, done)) obs = next_obs if done: @@ -47,12 +46,13 @@ def train(self, parameter_server_rref, experience_server_rref): self.agent.update_params(batch, 1) parameter_server_rref.rpc_sync().store_weights(self.agent.get_weights()) print(f"Trainer: {i + 1} / {self.train_steps} steps completed") + self.evaluate() i += 1 mp.set_start_method("fork") -master = Master(world_size=8, address="localhost", port=29500) +master = Master(world_size=6, address="localhost", port=29500) env = gym.make("Pendulum-v0") agent = DDPG("mlp", env) parameter_server = ParameterServer( @@ -61,9 +61,7 @@ def train(self, parameter_server_rref, experience_server_rref): buffer = ReplayBuffer(BUFFER_SIZE) experience_server = ExperienceServer("experience-0", master, buffer, rank=2) trainer = MyTrainer(agent, TRAIN_STEPS, BATCH_SIZE) -learner = LearnerNode( - "learner-0", master, "param-0", "experience-0", trainer, rank=3 -) +learner = LearnerNode("learner-0", master, "param-0", "experience-0", trainer, rank=3) actors = [ ActorNode( name=f"actor-{i}", diff --git a/genrl/distributed/core.py b/genrl/distributed/core.py index f4f2f750..5f25ce5d 100644 --- a/genrl/distributed/core.py +++ b/genrl/distributed/core.py @@ -132,7 +132,7 @@ def __del__(self): def _run_master(world_size): print(f"Starting master at {os.getpid()}") rpc.init_rpc("master", rank=0, world_size=world_size) - rpc.shutdown() + rpc.shutdown() @property def world_size(self): diff --git a/genrl/distributed/learner.py b/genrl/distributed/learner.py index ebce8407..a5a74554 
100644 --- a/genrl/distributed/learner.py +++ b/genrl/distributed/learner.py @@ -6,7 +6,13 @@ class LearnerNode(Node): def __init__( - self, name, master, parameter_server_name, experience_server_name, trainer, rank=None + self, + name, + master, + parameter_server_name, + experience_server_name, + trainer, + rank=None, ): super(LearnerNode, self).__init__(name, master, rank) self.init_proc( diff --git a/genrl/trainers/distributed.py b/genrl/trainers/distributed.py index 6f14e2e9..a353460b 100644 --- a/genrl/trainers/distributed.py +++ b/genrl/trainers/distributed.py @@ -1,14 +1,15 @@ -from abc import ABC, abstractmethod +from genrl.trainers import Trainer +import numpy as np -class DistributedTrainer(ABC): +class DistributedTrainer: def __init__(self, agent): self.agent = agent + self.env = self.agent.env self._completed_training_flag = False - @abstractmethod def train(self, parameter_server_rref, experience_server_rref): - pass + raise NotImplementedError def train_wrapper(self, parameter_server_rref, experience_server_rref): self._completed_training_flag = False @@ -18,173 +19,28 @@ def train_wrapper(self, parameter_server_rref, experience_server_rref): def is_done(self): return self._completed_training_flag - -# class DistributedOffPolicyTrainer: -# """Distributed Off Policy Trainer Class - -# Trainer class for Distributed Off Policy Agents - -# """ - -# def __init__( -# self, -# agent, -# env, -# buffer_server_port=None, -# param_server_port=None, -# **kwargs, -# ): -# self.env = env -# self.agent = agent -# self.buffer_server_port = buffer_server_port -# self.param_server_port = param_server_port - -# def train( -# self, n_actors, max_buffer_size, batch_size, max_updates, update_interval -# ): -# buffer_server = reverb.Server( -# tables=[ -# reverb.Table( -# name="replay_buffer", -# sampler=reverb.selectors.Uniform(), -# remover=reverb.selectors.Fifo(), -# max_size=max_buffer_size, -# rate_limiter=reverb.rate_limiters.MinSize(1), -# ) -# ], -# port=self.buffer_server_port, -# ) -# buffer_server_address = f"localhost:{buffer_server.port}" - -# param_server = reverb.Server( -# tables=[ -# reverb.Table( -# name="param_buffer", -# sampler=reverb.selectors.Uniform(), -# remover=reverb.selectors.Fifo(), -# max_size=1, -# rate_limiter=reverb.rate_limiters.MinSize(1), -# ) -# ], -# port=self.param_server_port, -# ) -# param_server_address = f"localhost:{param_server.port}" - -# actor_procs = [] -# for _ in range(n_actors): -# p = threading.Thread( -# target=run_actor, -# args=( -# copy.deepcopy(self.agent), -# copy.deepcopy(self.env), -# buffer_server_address, -# param_server_address, -# ), -# daemon=True, -# ) -# p.start() -# actor_procs.append(p) - -# learner_proc = threading.Thread( -# target=run_learner, -# args=( -# copy.deepcopy(self.agent), -# max_updates, -# update_interval, -# buffer_server_address, -# param_server_address, -# batch_size, -# ), -# daemon=True, -# ) -# learner_proc.daemon = True -# learner_proc.start() -# learner_proc.join() - -# # param_client = reverb.Client(param_server_address) -# # self.agent.replay_buffer = ReverbReplayDataset( -# # self.agent.env, buffer_server_address, batch_size -# # ) - -# # for _ in range(max_updates): -# # self.agent.update_params(update_interval) -# # params = self.agent.get_weights() -# # param_client.insert(params.values(), {"param_buffer": 1}) -# # print("weights updated") -# # # print(list(param_client.sample("param_buffer"))) - - -# def run_actor(agent, env, buffer_server_address, param_server_address): -# buffer_client = 
reverb.Client(buffer_server_address) -# param_client = reverb.TFClient(param_server_address) - -# state = env.reset().astype(np.float32) - -# for i in range(10): -# # params = param_client.sample("param_buffer", []) -# # print("Sampling done") -# # print(list(params)) -# # agent.load_weights(params) - -# action = agent.select_action(state).numpy() -# next_state, reward, done, _ = env.step(action) -# next_state = next_state.astype(np.float32) -# reward = np.array([reward]).astype(np.float32) -# done = np.array([done]).astype(np.bool) - -# buffer_client.insert([state, action, reward, next_state, done], {"replay_buffer": 1}) -# print("transition inserted") -# state = env.reset().astype(np.float32) if done else next_state.copy() - - -# def run_learner( -# agent, -# max_updates, -# update_interval, -# buffer_server_address, -# param_server_address, -# batch_size, -# ): -# param_client = reverb.Client(param_server_address) -# agent.replay_buffer = ReverbReplayDataset( -# agent.env, buffer_server_address, batch_size -# ) -# for _ in range(max_updates): -# agent.update_params(update_interval) -# params = agent.get_weights() -# param_client.insert(params.values(), {"param_buffer": 1}) -# print("weights updated") -# # print(list(param_client.sample("param_buffer"))) - - -# class ReverbReplayDataset: -# def __init__(self, env, address, batch_size): -# action_dtype = ( -# np.int64 -# if isinstance(env.action_space, gym.spaces.discrete.Discrete) -# else np.float32 -# ) -# obs_shape = env.observation_space.shape -# action_shape = env.action_space.shape -# reward_shape = 1 -# done_shape = 1 - -# self._dataset = reverb.ReplayDataset( -# server_address=address, -# table="replay_buffer", -# max_in_flight_samples_per_worker=2 * batch_size, -# dtypes=(np.float32, action_dtype, np.float32, np.float32, np.bool), -# shapes=( -# tf.TensorShape(obs_shape), -# tf.TensorShape(action_shape), -# tf.TensorShape(reward_shape), -# tf.TensorShape(obs_shape), -# tf.TensorShape(done_shape), -# ), -# ) -# self._data_iter = self._dataset.batch(batch_size).as_numpy_iterator() - -# def sample(self, *args, **kwargs): -# sample = next(self._data_iter) -# obs, a, r, next_obs, d = [torch.from_numpy(t).float() for t in sample.data] -# return obs, a, r, next_obs, d + def evaluate(self, render: bool = False) -> None: + """Evaluate performance of Agent + + Args: + render (bool): Option to render the environment during evaluation + """ + episode_reward = 0 + episode_rewards = [] + state = self.env.reset() + done = False + for i in range(10): + while not done: + action = self.agent.select_action(state, deterministic=True) + next_state, reward, done, _ = self.env.step(action) + episode_reward += reward + state = next_state + episode_rewards.append(episode_reward) + episode_reward = 0 + print( + "Evaluated for {} episodes, Mean Reward: {:.2f}, Std Deviation for the Reward: {:.2f}".format( + 10, + np.mean(episode_rewards), + np.std(episode_rewards), + ) + ) From f325429c5fb8a649ff85dea8c2d5da205aae2ec5 Mon Sep 17 00:00:00 2001 From: Atharv Sonwane Date: Fri, 23 Oct 2020 06:45:03 +0000 Subject: [PATCH 11/27] added proxy getter --- examples/distributed.py | 20 +++++++------------- genrl/distributed/actor.py | 17 ++++++++--------- genrl/distributed/core.py | 16 ++++++++++++++-- genrl/distributed/experience_server.py | 3 +-- genrl/distributed/learner.py | 13 +++++-------- genrl/distributed/parameter_server.py | 5 ++--- 6 files changed, 37 insertions(+), 37 deletions(-) diff --git a/examples/distributed.py b/examples/distributed.py index 
093d5106..8d5dda0f 100644 --- a/examples/distributed.py +++ b/examples/distributed.py @@ -4,13 +4,11 @@ ParameterServer, ActorNode, LearnerNode, - WeightHolder, ) from genrl.core import ReplayBuffer from genrl.agents import DDPG from genrl.trainers import DistributedTrainer import gym -import torch.multiprocessing as mp N_ACTORS = 2 BUFFER_SIZE = 10 @@ -19,13 +17,13 @@ BATCH_SIZE = 1 -def collect_experience(agent, experience_server_rref): +def collect_experience(agent, experience_server): obs = agent.env.reset() done = False for i in range(MAX_ENV_STEPS): action = agent.select_action(obs) next_obs, reward, done, _ = agent.env.step(action) - experience_server_rref.rpc_sync().push((obs, action, reward, next_obs, done)) + experience_server.push((obs, action, reward, next_obs, done)) obs = next_obs if done: break @@ -37,27 +35,23 @@ def __init__(self, agent, train_steps, batch_size): self.train_steps = train_steps self.batch_size = batch_size - def train(self, parameter_server_rref, experience_server_rref): + def train(self, parameter_server, experience_server): i = 0 while i < self.train_steps: - batch = experience_server_rref.rpc_sync().sample(self.batch_size) + batch = experience_server.sample(self.batch_size) if batch is None: continue self.agent.update_params(batch, 1) - parameter_server_rref.rpc_sync().store_weights(self.agent.get_weights()) + parameter_server.store_weights(self.agent.get_weights()) print(f"Trainer: {i + 1} / {self.train_steps} steps completed") self.evaluate() i += 1 -mp.set_start_method("fork") - -master = Master(world_size=6, address="localhost", port=29500) +master = Master(world_size=6, address="localhost", port=29500, proc_start_method="fork") env = gym.make("Pendulum-v0") agent = DDPG("mlp", env) -parameter_server = ParameterServer( - "param-0", master, WeightHolder(agent.get_weights()), rank=1 -) +parameter_server = ParameterServer("param-0", master, agent.get_weights(), rank=1) buffer = ReplayBuffer(BUFFER_SIZE) experience_server = ExperienceServer("experience-0", master, buffer, rank=2) trainer = MyTrainer(agent, TRAIN_STEPS, BATCH_SIZE) diff --git a/genrl/distributed/actor.py b/genrl/distributed/actor.py index 32dd0f0e..923fffd5 100644 --- a/genrl/distributed/actor.py +++ b/genrl/distributed/actor.py @@ -1,5 +1,5 @@ from genrl.distributed.core import Node -from genrl.distributed.core import get_rref, store_rref +from genrl.distributed.core import get_proxy, store_rref import torch.distributed.rpc as rpc @@ -42,14 +42,13 @@ def act( ): rpc.init_rpc(name=name, world_size=world_size, rank=rank) print(f"{name}: RPC Initialised") - rref = rpc.RRef(agent) - store_rref(name, rref) - parameter_server_rref = get_rref(parameter_server_name) - experience_server_rref = get_rref(experience_server_name) - learner_rref = get_rref(learner_name) + store_rref(name, rpc.RRef(agent)) + parameter_server = get_proxy(parameter_server_name) + experience_server = get_proxy(experience_server_name) + learner = get_proxy(learner_name) print(f"{name}: Begining experience collection") - while not learner_rref.rpc_sync().is_done(): - agent.load_weights(parameter_server_rref.rpc_sync().get_weights()) - collect_experience(agent, experience_server_rref) + while not learner.is_done(): + agent.load_weights(parameter_server.get_weights()) + collect_experience(agent, experience_server) rpc.shutdown() diff --git a/genrl/distributed/core.py b/genrl/distributed/core.py index 5f25ce5d..009eccce 100644 --- a/genrl/distributed/core.py +++ b/genrl/distributed/core.py @@ -47,6 +47,10 @@ def 
store_rref(idx, rref): rpc.rpc_sync("master", _store_rref, args=(idx, rref)) +def get_proxy(idx): + return get_rref(idx).rpc_sync() + + def set_environ(address, port): os.environ["MASTER_ADDR"] = str(address) os.environ["MASTER_PORT"] = str(port) @@ -111,7 +115,15 @@ def rank(self): class Master: - def __init__(self, world_size, address="localhost", port=29501, secondary=False): + def __init__( + self, + world_size, + address="localhost", + port=29501, + secondary=False, + proc_start_method="fork", + ): + mp.set_start_method(proc_start_method) set_environ(address, port) self._world_size = world_size self._address = address @@ -130,7 +142,7 @@ def __del__(self): @staticmethod def _run_master(world_size): - print(f"Starting master at {os.getpid()}") + print(f"Starting master with pid {os.getpid()}") rpc.init_rpc("master", rank=0, world_size=world_size) rpc.shutdown() diff --git a/genrl/distributed/experience_server.py b/genrl/distributed/experience_server.py index 95335569..3dae39fa 100644 --- a/genrl/distributed/experience_server.py +++ b/genrl/distributed/experience_server.py @@ -17,7 +17,6 @@ def __init__(self, name, master, buffer, rank=None): def run_paramater_server(name, world_size, rank, buffer, **kwargs): rpc.init_rpc(name=name, world_size=world_size, rank=rank) print(f"{name}: Initialised RPC") - rref = rpc.RRef(buffer) - store_rref(name, rref) + store_rref(name, rpc.RRef(buffer)) print(f"{name}: Serving experience buffer") rpc.shutdown() diff --git a/genrl/distributed/learner.py b/genrl/distributed/learner.py index a5a74554..0700a666 100644 --- a/genrl/distributed/learner.py +++ b/genrl/distributed/learner.py @@ -1,5 +1,5 @@ from genrl.distributed import Node -from genrl.distributed.core import get_rref, store_rref +from genrl.distributed.core import get_proxy, store_rref import torch.distributed.rpc as rpc @@ -37,12 +37,9 @@ def learn( ): rpc.init_rpc(name=name, world_size=world_size, rank=rank) print(f"{name}: Initialised RPC") - rref = rpc.RRef(trainer) - store_rref(name, rref) - parameter_server_rref = get_rref(parameter_server_name) - experience_server_rref = get_rref( - experience_server_name, - ) + store_rref(name, rpc.RRef(trainer)) + parameter_server = get_proxy(parameter_server_name) + experience_server = get_proxy(experience_server_name) print(f"{name}: Beginning training") - trainer.train_wrapper(parameter_server_rref, experience_server_rref) + trainer.train_wrapper(parameter_server, experience_server) rpc.shutdown() diff --git a/genrl/distributed/parameter_server.py b/genrl/distributed/parameter_server.py index ae5ec805..0675d825 100644 --- a/genrl/distributed/parameter_server.py +++ b/genrl/distributed/parameter_server.py @@ -17,9 +17,8 @@ def __init__(self, name, master, init_params, rank=None): def run_paramater_server(name, world_size, rank, init_params, **kwargs): rpc.init_rpc(name=name, world_size=world_size, rank=rank) print(f"{name}: Initialised RPC") - params = init_params - rref = rpc.RRef(params) - store_rref(name, rref) + params = WeightHolder(init_weights=init_params) + store_rref(name, rpc.RRef(params)) print(f"{name}: Serving parameters") rpc.shutdown() From 7ce19ec120d014748172934ebf6b8b3fb80e0a80 Mon Sep 17 00:00:00 2001 From: Atharv Sonwane Date: Fri, 23 Oct 2020 07:04:35 +0000 Subject: [PATCH 12/27] added rpc backend option --- examples/distributed.py | 10 +++++++++- genrl/distributed/actor.py | 3 ++- genrl/distributed/core.py | 24 +++++++++++++++++++++--- genrl/distributed/experience_server.py | 4 ++-- genrl/distributed/learner.py | 3 ++- 
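In user code, the effect of the new proxy getter is simply to hide the explicit rpc_sync() hop. A before/after sketch, assuming it runs inside a node process whose RPC has been initialised and where an experience server has been registered under the name "experience-0":

from genrl.distributed.core import get_rref, get_proxy

# before this change: callers held an RRef and spelled out the rpc_sync() hop
experience_server_rref = get_rref("experience-0")
batch = experience_server_rref.rpc_sync().sample(64)

# after this change: get_proxy returns the rpc_sync() proxy directly, so remote
# calls read like ordinary method calls
experience_server = get_proxy("experience-0")
batch = experience_server.sample(64)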
genrl/distributed/parameter_server.py | 6 ++++-- 6 files changed, 40 insertions(+), 10 deletions(-) diff --git a/examples/distributed.py b/examples/distributed.py index 8d5dda0f..f1075ae3 100644 --- a/examples/distributed.py +++ b/examples/distributed.py @@ -9,6 +9,8 @@ from genrl.agents import DDPG from genrl.trainers import DistributedTrainer import gym +import torch.distributed.rpc as rpc + N_ACTORS = 2 BUFFER_SIZE = 10 @@ -48,7 +50,13 @@ def train(self, parameter_server, experience_server): i += 1 -master = Master(world_size=6, address="localhost", port=29500, proc_start_method="fork") +master = Master( + world_size=6, + address="localhost", + port=29500, + proc_start_method="fork", + rpc_backend=rpc.BackendType.TENSORPIPE, +) env = gym.make("Pendulum-v0") agent = DDPG("mlp", env) parameter_server = ParameterServer("param-0", master, agent.get_weights(), rank=1) diff --git a/genrl/distributed/actor.py b/genrl/distributed/actor.py index 923fffd5..539f19af 100644 --- a/genrl/distributed/actor.py +++ b/genrl/distributed/actor.py @@ -38,9 +38,10 @@ def act( learner_name, agent, collect_experience, + rpc_backend, **kwargs, ): - rpc.init_rpc(name=name, world_size=world_size, rank=rank) + rpc.init_rpc(name=name, world_size=world_size, rank=rank, backend=rpc_backend) print(f"{name}: RPC Initialised") store_rref(name, rpc.RRef(agent)) parameter_server = get_proxy(parameter_server_name) diff --git a/genrl/distributed/core.py b/genrl/distributed/core.py index 009eccce..1adf6805 100644 --- a/genrl/distributed/core.py +++ b/genrl/distributed/core.py @@ -92,6 +92,7 @@ def init_proc(self, target, kwargs): master_port=self.master.port, world_size=self.master.world_size, rank=self.rank, + rpc_backend=self.master.rpc_backend, ) ) self.p = mp.Process(target=self._target_wrapper, args=(target,), kwargs=kwargs) @@ -122,6 +123,7 @@ def __init__( port=29501, secondary=False, proc_start_method="fork", + rpc_backend=rpc.BackendType.PROCESS_GROUP, ): mp.set_start_method(proc_start_method) set_environ(address, port) @@ -129,9 +131,21 @@ def __init__( self._address = address self._port = port self._secondary = secondary + self._rpc_backend = rpc_backend + + print( + "Configuration - {\n" + f"RPC Address : {self.address}\n" + f"RPC Port : {self.port}\n" + f"RPC World Size : {self.world_size}\n" + f"RPC Backend : {self.rpc_backend}\n" + f"Process Start Method : {proc_start_method}\n" + f"Seondary Master : {self.is_secondary}\n" + "}" + ) if not self._secondary: - self.p = mp.Process(target=self._run_master, args=(world_size,)) + self.p = mp.Process(target=self._run_master, args=(world_size, rpc_backend)) self.p.start() else: self.p = None @@ -141,9 +155,9 @@ def __del__(self): self.p.join() @staticmethod - def _run_master(world_size): + def _run_master(world_size, rpc_backend): print(f"Starting master with pid {os.getpid()}") - rpc.init_rpc("master", rank=0, world_size=world_size) + rpc.init_rpc("master", rank=0, world_size=world_size, backend=rpc_backend) rpc.shutdown() @property @@ -161,3 +175,7 @@ def port(self): @property def is_secondary(self): return self._secondary + + @property + def rpc_backend(self): + return self._rpc_backend diff --git a/genrl/distributed/experience_server.py b/genrl/distributed/experience_server.py index 3dae39fa..c2c7bccb 100644 --- a/genrl/distributed/experience_server.py +++ b/genrl/distributed/experience_server.py @@ -14,8 +14,8 @@ def __init__(self, name, master, buffer, rank=None): self.start_proc() @staticmethod - def run_paramater_server(name, world_size, rank, buffer, 
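With the previous patch the parameter server node wraps the initial weights in a WeightHolder itself, so the example no longer constructs one. WeightHolder is defined in genrl.distributed and is not shown in these patches; purely to illustrate the interface the actors and the learner rely on (store_weights and get_weights), a hypothetical minimal stand-in could look like:

class WeightHolder:
    """Hypothetical minimal stand-in, for illustration only; the real class is
    defined in genrl.distributed and may do more than this."""

    def __init__(self, init_weights):
        self._weights = init_weights

    def store_weights(self, weights):
        self._weights = weights

    def get_weights(self):
        return self._weights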
**kwargs): - rpc.init_rpc(name=name, world_size=world_size, rank=rank) + def run_paramater_server(name, world_size, rank, buffer, rpc_backend, **kwargs): + rpc.init_rpc(name=name, world_size=world_size, rank=rank, backend=rpc_backend) print(f"{name}: Initialised RPC") store_rref(name, rpc.RRef(buffer)) print(f"{name}: Serving experience buffer") diff --git a/genrl/distributed/learner.py b/genrl/distributed/learner.py index 0700a666..90f9f178 100644 --- a/genrl/distributed/learner.py +++ b/genrl/distributed/learner.py @@ -33,9 +33,10 @@ def learn( parameter_server_name, experience_server_name, trainer, + rpc_backend, **kwargs, ): - rpc.init_rpc(name=name, world_size=world_size, rank=rank) + rpc.init_rpc(name=name, world_size=world_size, rank=rank, backend=rpc_backend) print(f"{name}: Initialised RPC") store_rref(name, rpc.RRef(trainer)) parameter_server = get_proxy(parameter_server_name) diff --git a/genrl/distributed/parameter_server.py b/genrl/distributed/parameter_server.py index 0675d825..cde20d14 100644 --- a/genrl/distributed/parameter_server.py +++ b/genrl/distributed/parameter_server.py @@ -14,8 +14,10 @@ def __init__(self, name, master, init_params, rank=None): self.start_proc() @staticmethod - def run_paramater_server(name, world_size, rank, init_params, **kwargs): - rpc.init_rpc(name=name, world_size=world_size, rank=rank) + def run_paramater_server( + name, world_size, rank, init_params, rpc_backend, **kwargs + ): + rpc.init_rpc(name=name, world_size=world_size, rank=rank, backend=rpc_backend) print(f"{name}: Initialised RPC") params = WeightHolder(init_weights=init_params) store_rref(name, rpc.RRef(params)) From cfba909f101efb353ddc503415630d11c980101d Mon Sep 17 00:00:00 2001 From: Atharv Sonwane Date: Fri, 23 Oct 2020 07:31:29 +0000 Subject: [PATCH 13/27] added logging to trainer --- examples/distributed.py | 7 ++++--- genrl/agents/deep/ddpg/ddpg.py | 2 +- genrl/trainers/distributed.py | 16 +++++++++------- 3 files changed, 14 insertions(+), 11 deletions(-) diff --git a/examples/distributed.py b/examples/distributed.py index f1075ae3..f251861d 100644 --- a/examples/distributed.py +++ b/examples/distributed.py @@ -8,6 +8,7 @@ from genrl.core import ReplayBuffer from genrl.agents import DDPG from genrl.trainers import DistributedTrainer +from genrl.utils import Logger import gym import torch.distributed.rpc as rpc @@ -36,6 +37,7 @@ def __init__(self, agent, train_steps, batch_size): super(MyTrainer, self).__init__(agent) self.train_steps = train_steps self.batch_size = batch_size + self.logger = Logger(formats=["stdout"]) def train(self, parameter_server, experience_server): i = 0 @@ -43,10 +45,9 @@ def train(self, parameter_server, experience_server): batch = experience_server.sample(self.batch_size) if batch is None: continue - self.agent.update_params(batch, 1) + self.agent.update_params(1, batch) parameter_server.store_weights(self.agent.get_weights()) - print(f"Trainer: {i + 1} / {self.train_steps} steps completed") - self.evaluate() + self.evaluate(i) i += 1 diff --git a/genrl/agents/deep/ddpg/ddpg.py b/genrl/agents/deep/ddpg/ddpg.py index 9ed54a02..4efbbfdf 100644 --- a/genrl/agents/deep/ddpg/ddpg.py +++ b/genrl/agents/deep/ddpg/ddpg.py @@ -79,7 +79,7 @@ def _create_model(self) -> None: self.optimizer_policy = opt.Adam(self.ac.actor.parameters(), lr=self.lr_policy) self.optimizer_value = opt.Adam(self.ac.critic.parameters(), lr=self.lr_value) - def update_params(self, batch, update_interval: int) -> None: + def update_params(self, update_interval: int, batch = None) -> 
None: """Update parameters of the model Args: diff --git a/genrl/trainers/distributed.py b/genrl/trainers/distributed.py index a353460b..f82b2a4a 100644 --- a/genrl/trainers/distributed.py +++ b/genrl/trainers/distributed.py @@ -1,5 +1,6 @@ from genrl.trainers import Trainer import numpy as np +from genrl.utils import safe_mean class DistributedTrainer: @@ -19,7 +20,7 @@ def train_wrapper(self, parameter_server_rref, experience_server_rref): def is_done(self): return self._completed_training_flag - def evaluate(self, render: bool = False) -> None: + def evaluate(self, timestep, render: bool = False) -> None: """Evaluate performance of Agent Args: @@ -37,10 +38,11 @@ def evaluate(self, render: bool = False) -> None: state = next_state episode_rewards.append(episode_reward) episode_reward = 0 - print( - "Evaluated for {} episodes, Mean Reward: {:.2f}, Std Deviation for the Reward: {:.2f}".format( - 10, - np.mean(episode_rewards), - np.std(episode_rewards), - ) + self.logger.write( + { + "timestep": timestep, + **self.agent.get_logging_params(), + "Episode Reward": safe_mean(episode_rewards), + }, + "timestep", ) From 992a3a910061873cb56dfc3baa4e83375099f8c6 Mon Sep 17 00:00:00 2001 From: Atharv Sonwane Date: Fri, 23 Oct 2020 10:56:10 +0000 Subject: [PATCH 14/27] Added more options to trainer --- examples/distributed.py | 28 ++++++++++++++++++---------- genrl/core/buffers.py | 2 +- genrl/trainers/distributed.py | 13 +++++++------ 3 files changed, 26 insertions(+), 17 deletions(-) diff --git a/examples/distributed.py b/examples/distributed.py index f251861d..9085d61a 100644 --- a/examples/distributed.py +++ b/examples/distributed.py @@ -11,20 +11,27 @@ from genrl.utils import Logger import gym import torch.distributed.rpc as rpc +import time N_ACTORS = 2 -BUFFER_SIZE = 10 -MAX_ENV_STEPS = 100 -TRAIN_STEPS = 50 -BATCH_SIZE = 1 +BUFFER_SIZE = 5000 +MAX_ENV_STEPS = 500 +TRAIN_STEPS = 100 +BATCH_SIZE = 64 +INIT_BUFFER_SIZE = 500 +WARMUP_STEPS = 500 def collect_experience(agent, experience_server): obs = agent.env.reset() done = False for i in range(MAX_ENV_STEPS): - action = agent.select_action(obs) + action = ( + agent.env.action_space.sample() + if i < WARMUP_STEPS + else agent.select_action(obs) + ) next_obs, reward, done, _ = agent.env.step(action) experience_server.push((obs, action, reward, next_obs, done)) obs = next_obs @@ -33,22 +40,23 @@ def collect_experience(agent, experience_server): class MyTrainer(DistributedTrainer): - def __init__(self, agent, train_steps, batch_size): + def __init__(self, agent, train_steps, batch_size, init_buffer_size): super(MyTrainer, self).__init__(agent) self.train_steps = train_steps self.batch_size = batch_size + self.init_buffer_size = init_buffer_size self.logger = Logger(formats=["stdout"]) def train(self, parameter_server, experience_server): - i = 0 - while i < self.train_steps: + while experience_server.__len__() < self.init_buffer_size: + time.sleep(1) + for i in range(self.train_steps): batch = experience_server.sample(self.batch_size) if batch is None: continue self.agent.update_params(1, batch) parameter_server.store_weights(self.agent.get_weights()) self.evaluate(i) - i += 1 master = Master( @@ -63,7 +71,7 @@ def train(self, parameter_server, experience_server): parameter_server = ParameterServer("param-0", master, agent.get_weights(), rank=1) buffer = ReplayBuffer(BUFFER_SIZE) experience_server = ExperienceServer("experience-0", master, buffer, rank=2) -trainer = MyTrainer(agent, TRAIN_STEPS, BATCH_SIZE) +trainer = MyTrainer(agent, 
TRAIN_STEPS, BATCH_SIZE, INIT_BUFFER_SIZE) learner = LearnerNode("learner-0", master, "param-0", "experience-0", trainer, rank=3) actors = [ ActorNode( diff --git a/genrl/core/buffers.py b/genrl/core/buffers.py index b73067f1..207c7d0b 100644 --- a/genrl/core/buffers.py +++ b/genrl/core/buffers.py @@ -73,7 +73,7 @@ def __len__(self) -> int: :returns: Length of replay memory """ - return self.pos + return len(self.memory) class PrioritizedBuffer: diff --git a/genrl/trainers/distributed.py b/genrl/trainers/distributed.py index f82b2a4a..1841eca8 100644 --- a/genrl/trainers/distributed.py +++ b/genrl/trainers/distributed.py @@ -9,12 +9,12 @@ def __init__(self, agent): self.env = self.agent.env self._completed_training_flag = False - def train(self, parameter_server_rref, experience_server_rref): + def train(self, parameter_server, experience_server): raise NotImplementedError - def train_wrapper(self, parameter_server_rref, experience_server_rref): + def train_wrapper(self, parameter_server, experience_server): self._completed_training_flag = False - self.train(parameter_server_rref, experience_server_rref) + self.train(parameter_server, experience_server) self._completed_training_flag = True def is_done(self): @@ -26,11 +26,11 @@ def evaluate(self, timestep, render: bool = False) -> None: Args: render (bool): Option to render the environment during evaluation """ - episode_reward = 0 episode_rewards = [] - state = self.env.reset() - done = False for i in range(10): + state = self.env.reset() + done = False + episode_reward = 0 while not done: action = self.agent.select_action(state, deterministic=True) next_state, reward, done, _ = self.env.step(action) @@ -38,6 +38,7 @@ def evaluate(self, timestep, render: bool = False) -> None: state = next_state episode_rewards.append(episode_reward) episode_reward = 0 + self.logger.write( { "timestep": timestep, From bf1a50a70c31fa315f0a474bd36fc962676c7ae2 Mon Sep 17 00:00:00 2001 From: Atharv Sonwane Date: Fri, 23 Oct 2020 13:09:45 +0000 Subject: [PATCH 15/27] moved load weights to user --- examples/distributed.py | 16 ++++++++++------ genrl/distributed/actor.py | 3 +-- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/examples/distributed.py b/examples/distributed.py index 9085d61a..997afa15 100644 --- a/examples/distributed.py +++ b/examples/distributed.py @@ -17,13 +17,14 @@ N_ACTORS = 2 BUFFER_SIZE = 5000 MAX_ENV_STEPS = 500 -TRAIN_STEPS = 100 +TRAIN_STEPS = 5000 BATCH_SIZE = 64 -INIT_BUFFER_SIZE = 500 -WARMUP_STEPS = 500 +INIT_BUFFER_SIZE = 1000 +WARMUP_STEPS = 1000 -def collect_experience(agent, experience_server): +def collect_experience(agent, parameter_server, experience_server): + agent.load_weights(parameter_server.get_weights()) obs = agent.env.reset() done = False for i in range(MAX_ENV_STEPS): @@ -37,15 +38,17 @@ def collect_experience(agent, experience_server): obs = next_obs if done: break + time.sleep(1) class MyTrainer(DistributedTrainer): - def __init__(self, agent, train_steps, batch_size, init_buffer_size): + def __init__(self, agent, train_steps, batch_size, init_buffer_size, log_interval=200): super(MyTrainer, self).__init__(agent) self.train_steps = train_steps self.batch_size = batch_size self.init_buffer_size = init_buffer_size self.logger = Logger(formats=["stdout"]) + self.log_interval = log_interval def train(self, parameter_server, experience_server): while experience_server.__len__() < self.init_buffer_size: @@ -56,7 +59,8 @@ def train(self, parameter_server, experience_server): continue 
self.agent.update_params(1, batch) parameter_server.store_weights(self.agent.get_weights()) - self.evaluate(i) + if i % self.log_interval == 0: + self.evaluate(i) master = Master( diff --git a/genrl/distributed/actor.py b/genrl/distributed/actor.py index 539f19af..f22209c7 100644 --- a/genrl/distributed/actor.py +++ b/genrl/distributed/actor.py @@ -49,7 +49,6 @@ def act( learner = get_proxy(learner_name) print(f"{name}: Begining experience collection") while not learner.is_done(): - agent.load_weights(parameter_server.get_weights()) - collect_experience(agent, experience_server) + collect_experience(agent, parameter_server, experience_server) rpc.shutdown() From e2eef667cf36717fb2c209a70059be7f26ac23ba Mon Sep 17 00:00:00 2001 From: Atharv Sonwane Date: Fri, 23 Oct 2020 13:09:56 +0000 Subject: [PATCH 16/27] decreased number of eval its --- genrl/trainers/distributed.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/genrl/trainers/distributed.py b/genrl/trainers/distributed.py index 1841eca8..c254a0b0 100644 --- a/genrl/trainers/distributed.py +++ b/genrl/trainers/distributed.py @@ -27,7 +27,7 @@ def evaluate(self, timestep, render: bool = False) -> None: render (bool): Option to render the environment during evaluation """ episode_rewards = [] - for i in range(10): + for i in range(5): state = self.env.reset() done = False episode_reward = 0 From 837eb18878083e8a53c360cbc8023b74efd95f9f Mon Sep 17 00:00:00 2001 From: Atharv Sonwane Date: Fri, 23 Oct 2020 13:45:43 +0000 Subject: [PATCH 17/27] removed train wrapper --- genrl/distributed/actor.py | 2 +- genrl/distributed/learner.py | 3 ++- genrl/trainers/distributed.py | 10 ++++------ 3 files changed, 7 insertions(+), 8 deletions(-) diff --git a/genrl/distributed/actor.py b/genrl/distributed/actor.py index f22209c7..513376d2 100644 --- a/genrl/distributed/actor.py +++ b/genrl/distributed/actor.py @@ -48,7 +48,7 @@ def act( experience_server = get_proxy(experience_server_name) learner = get_proxy(learner_name) print(f"{name}: Begining experience collection") - while not learner.is_done(): + while not learner.is_completed(): collect_experience(agent, parameter_server, experience_server) rpc.shutdown() diff --git a/genrl/distributed/learner.py b/genrl/distributed/learner.py index 90f9f178..c9181052 100644 --- a/genrl/distributed/learner.py +++ b/genrl/distributed/learner.py @@ -42,5 +42,6 @@ def learn( parameter_server = get_proxy(parameter_server_name) experience_server = get_proxy(experience_server_name) print(f"{name}: Beginning training") - trainer.train_wrapper(parameter_server, experience_server) + trainer.train(parameter_server, experience_server) + trainer.set_completed(True) rpc.shutdown() diff --git a/genrl/trainers/distributed.py b/genrl/trainers/distributed.py index c254a0b0..a93b7498 100644 --- a/genrl/trainers/distributed.py +++ b/genrl/trainers/distributed.py @@ -12,14 +12,12 @@ def __init__(self, agent): def train(self, parameter_server, experience_server): raise NotImplementedError - def train_wrapper(self, parameter_server, experience_server): - self._completed_training_flag = False - self.train(parameter_server, experience_server) - self._completed_training_flag = True - - def is_done(self): + def is_completed(self): return self._completed_training_flag + def set_completed(self, value=True): + self._completed_training_flag = value + def evaluate(self, timestep, render: bool = False) -> None: """Evaluate performance of Agent From 7fcbb233698ae549788b56367ef4b50fa00e592f Mon Sep 17 00:00:00 2001 From: Atharv 
Sonwane Date: Mon, 26 Oct 2020 12:06:53 +0000 Subject: [PATCH 18/27] removed loop to user fn --- examples/offpolicy_distributed_primary.py | 92 +++++++++++++++++++++++ genrl/distributed/actor.py | 9 +-- 2 files changed, 96 insertions(+), 5 deletions(-) create mode 100644 examples/offpolicy_distributed_primary.py diff --git a/examples/offpolicy_distributed_primary.py b/examples/offpolicy_distributed_primary.py new file mode 100644 index 00000000..9e289a4e --- /dev/null +++ b/examples/offpolicy_distributed_primary.py @@ -0,0 +1,92 @@ +from genrl.distributed import ( + Master, + ExperienceServer, + ParameterServer, + ActorNode, + LearnerNode, +) +from genrl.core import ReplayBuffer +from genrl.agents import DDPG +from genrl.trainers import DistributedTrainer +from genrl.utils import Logger +import gym +import torch.distributed.rpc as rpc +import time + + +N_ACTORS = 1 +BUFFER_SIZE = 5000 +MAX_ENV_STEPS = 500 +TRAIN_STEPS = 5000 +BATCH_SIZE = 64 +INIT_BUFFER_SIZE = 1000 +WARMUP_STEPS = 1000 + + +def run_actor(agent, parameter_server, experience_server, learner): + while not learner.is_completed(): + agent.load_weights(parameter_server.get_weights()) + obs = agent.env.reset() + done = False + for i in range(MAX_ENV_STEPS): + action = ( + agent.env.action_space.sample() + if i < WARMUP_STEPS + else agent.select_action(obs) + ) + next_obs, reward, done, _ = agent.env.step(action) + experience_server.push((obs, action, reward, next_obs, done)) + obs = next_obs + if done: + break + + +class MyTrainer(DistributedTrainer): + def __init__(self, agent, train_steps, batch_size, init_buffer_size, log_interval=200): + super(MyTrainer, self).__init__(agent) + self.train_steps = train_steps + self.batch_size = batch_size + self.init_buffer_size = init_buffer_size + self.logger = Logger(formats=["stdout"]) + self.log_interval = log_interval + + def train(self, parameter_server, experience_server): + while experience_server.__len__() < self.init_buffer_size: + time.sleep(1) + for i in range(self.train_steps): + batch = experience_server.sample(self.batch_size) + if batch is None: + continue + self.agent.update_params(1, batch) + parameter_server.store_weights(self.agent.get_weights()) + if i % self.log_interval == 0: + self.evaluate(i) + + +master = Master( + world_size=5, + address="localhost", + port=29502, + proc_start_method="fork", + rpc_backend=rpc.BackendType.TENSORPIPE, +) +env = gym.make("Pendulum-v0") +agent = DDPG("mlp", env) +parameter_server = ParameterServer("param-0", master, agent.get_weights(), rank=1) +buffer = ReplayBuffer(BUFFER_SIZE) +experience_server = ExperienceServer("experience-0", master, buffer, rank=2) +trainer = MyTrainer(agent, TRAIN_STEPS, BATCH_SIZE, INIT_BUFFER_SIZE) +learner = LearnerNode("learner-0", master, "param-0", "experience-0", trainer, rank=3) +actors = [ + ActorNode( + name=f"actor-{i}", + master=master, + parameter_server_name="param-0", + experience_server_name="experience-0", + learner_name="learner-0", + agent=agent, + run_actor=run_actor, + rank=i + 4, + ) + for i in range(N_ACTORS) +] diff --git a/genrl/distributed/actor.py b/genrl/distributed/actor.py index 513376d2..69347fbd 100644 --- a/genrl/distributed/actor.py +++ b/genrl/distributed/actor.py @@ -12,7 +12,7 @@ def __init__( experience_server_name, learner_name, agent, - collect_experience, + run_actor, rank=None, ): super(ActorNode, self).__init__(name, master, rank) @@ -23,7 +23,7 @@ def __init__( experience_server_name=experience_server_name, learner_name=learner_name, agent=agent, - 
collect_experience=collect_experience, + run_actor=run_actor, ), ) self.start_proc() @@ -37,7 +37,7 @@ def act( experience_server_name, learner_name, agent, - collect_experience, + run_actor, rpc_backend, **kwargs, ): @@ -48,7 +48,6 @@ def act( experience_server = get_proxy(experience_server_name) learner = get_proxy(learner_name) print(f"{name}: Begining experience collection") - while not learner.is_completed(): - collect_experience(agent, parameter_server, experience_server) + run_actor(agent, parameter_server, experience_server, learner) rpc.shutdown() From 0002fa4849c4a7550974be413b34bcc67d66af8e Mon Sep 17 00:00:00 2001 From: Atharv Sonwane Date: Mon, 26 Oct 2020 12:07:08 +0000 Subject: [PATCH 19/27] added example for secondary node --- examples/offpolicy_distributed_secondary.py | 80 +++++++++++++++++++++ 1 file changed, 80 insertions(+) create mode 100644 examples/offpolicy_distributed_secondary.py diff --git a/examples/offpolicy_distributed_secondary.py b/examples/offpolicy_distributed_secondary.py new file mode 100644 index 00000000..9a89f6d8 --- /dev/null +++ b/examples/offpolicy_distributed_secondary.py @@ -0,0 +1,80 @@ +from genrl.distributed import ( + Master, + ExperienceServer, + ParameterServer, + ActorNode, + LearnerNode, + WeightHolder, +) +from genrl.core import ReplayBuffer +from genrl.agents import DDPG +from genrl.trainers import DistributedTrainer +import gym +import argparse +import torch.multiprocessing as mp + + +N_ACTORS = 2 +BUFFER_SIZE = 10 +MAX_ENV_STEPS = 100 +TRAIN_STEPS = 10 +BATCH_SIZE = 1 + + +def collect_experience(agent, experience_server_rref): + obs = agent.env.reset() + done = False + for i in range(MAX_ENV_STEPS): + action = agent.select_action(obs) + next_obs, reward, done, info = agent.env.step(action) + experience_server_rref.rpc_sync().push((obs, action, reward, next_obs, done)) + obs = next_obs + if done: + break + + +# class MyTrainer(DistributedTrainer): +# def __init__(self, agent, train_steps, batch_size): +# super(MyTrainer, self).__init__(agent) +# self.train_steps = train_steps +# self.batch_size = batch_size + +# def train(self, parameter_server_rref, experience_server_rref): +# i = 0 +# while i < self.train_steps: +# batch = experience_server_rref.rpc_sync().sample(self.batch_size) +# if batch is None: +# continue +# self.agent.update_params(batch, 1) +# parameter_server_rref.rpc_sync().store_weights(self.agent.get_weights()) +# print(f"Trainer: {i + 1} / {self.train_steps} steps completed") +# i += 1 + + +mp.set_start_method("fork") + +master = Master(world_size=8, address="localhost", port=29500, secondary=True) +env = gym.make("Pendulum-v0") +agent = DDPG("mlp", env) +# parameter_server = ParameterServer( +# "param-0", master, WeightHolder(agent.get_weights()), rank=1 +# ) +# buffer = ReplayBuffer(BUFFER_SIZE) +# experience_server = ExperienceServer("experience-0", master, buffer, rank=2) +# trainer = MyTrainer(agent, TRAIN_STEPS, BATCH_SIZE) +# learner = LearnerNode( +# "learner-0", master, parameter_server, experience_server, trainer, rank=3 +# ) +actors = [ + ActorNode( + name=f"actor-{i+2}", + master=master, + parameter_server_name="param-0", + experience_server_name="experience-0", + learner_name="learner-0", + agent=agent, + collect_experience=collect_experience, + rank=i + 6, + ) + for i in range(N_ACTORS) +] From bebf50ff14141e4522206659de061115716ef72f Mon Sep 17 00:00:00 2001 From: Atharv Sonwane Date: Mon, 26 Oct 2020 12:07:39 +0000 Subject: [PATCH 20/27] removed original exmpale --- examples/distributed.py | 92 
----------------------------------------- 1 file changed, 92 deletions(-) delete mode 100644 examples/distributed.py diff --git a/examples/distributed.py b/examples/distributed.py deleted file mode 100644 index 997afa15..00000000 --- a/examples/distributed.py +++ /dev/null @@ -1,92 +0,0 @@ -from genrl.distributed import ( - Master, - ExperienceServer, - ParameterServer, - ActorNode, - LearnerNode, -) -from genrl.core import ReplayBuffer -from genrl.agents import DDPG -from genrl.trainers import DistributedTrainer -from genrl.utils import Logger -import gym -import torch.distributed.rpc as rpc -import time - - -N_ACTORS = 2 -BUFFER_SIZE = 5000 -MAX_ENV_STEPS = 500 -TRAIN_STEPS = 5000 -BATCH_SIZE = 64 -INIT_BUFFER_SIZE = 1000 -WARMUP_STEPS = 1000 - - -def collect_experience(agent, parameter_server, experience_server): - agent.load_weights(parameter_server.get_weights()) - obs = agent.env.reset() - done = False - for i in range(MAX_ENV_STEPS): - action = ( - agent.env.action_space.sample() - if i < WARMUP_STEPS - else agent.select_action(obs) - ) - next_obs, reward, done, _ = agent.env.step(action) - experience_server.push((obs, action, reward, next_obs, done)) - obs = next_obs - if done: - break - time.sleep(1) - - -class MyTrainer(DistributedTrainer): - def __init__(self, agent, train_steps, batch_size, init_buffer_size, log_interval=200): - super(MyTrainer, self).__init__(agent) - self.train_steps = train_steps - self.batch_size = batch_size - self.init_buffer_size = init_buffer_size - self.logger = Logger(formats=["stdout"]) - self.log_interval = log_interval - - def train(self, parameter_server, experience_server): - while experience_server.__len__() < self.init_buffer_size: - time.sleep(1) - for i in range(self.train_steps): - batch = experience_server.sample(self.batch_size) - if batch is None: - continue - self.agent.update_params(1, batch) - parameter_server.store_weights(self.agent.get_weights()) - if i % self.log_interval == 0: - self.evaluate(i) - - -master = Master( - world_size=6, - address="localhost", - port=29500, - proc_start_method="fork", - rpc_backend=rpc.BackendType.TENSORPIPE, -) -env = gym.make("Pendulum-v0") -agent = DDPG("mlp", env) -parameter_server = ParameterServer("param-0", master, agent.get_weights(), rank=1) -buffer = ReplayBuffer(BUFFER_SIZE) -experience_server = ExperienceServer("experience-0", master, buffer, rank=2) -trainer = MyTrainer(agent, TRAIN_STEPS, BATCH_SIZE, INIT_BUFFER_SIZE) -learner = LearnerNode("learner-0", master, "param-0", "experience-0", trainer, rank=3) -actors = [ - ActorNode( - name=f"actor-{i}", - master=master, - parameter_server_name="param-0", - experience_server_name="experience-0", - learner_name="learner-0", - agent=agent, - collect_experience=collect_experience, - rank=i + 4, - ) - for i in range(N_ACTORS) -] From 29bd1d6291706c1e1c6b329d32c72c22ea3ccfeb Mon Sep 17 00:00:00 2001 From: Atharv Sonwane Date: Mon, 26 Oct 2020 12:10:47 +0000 Subject: [PATCH 21/27] removed fn --- examples/offpolicy_distributed_primary.py | 4 ++-- genrl/distributed/actor.py | 9 ++++----- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/examples/offpolicy_distributed_primary.py b/examples/offpolicy_distributed_primary.py index 9e289a4e..f435721c 100644 --- a/examples/offpolicy_distributed_primary.py +++ b/examples/offpolicy_distributed_primary.py @@ -23,7 +23,7 @@ WARMUP_STEPS = 1000 -def run_actor(agent, parameter_server, experience_server, learner): +def collect_experience(agent, parameter_server, experience_server, learner): while not 
learner.is_completed(): agent.load_weights(parameter_server.get_weights()) obs = agent.env.reset() @@ -85,7 +85,7 @@ def train(self, parameter_server, experience_server): experience_server_name="experience-0", learner_name="learner-0", agent=agent, - run_actor=run_actor, + collect_experience=collect_experience, rank=i + 4, ) for i in range(N_ACTORS) diff --git a/genrl/distributed/actor.py b/genrl/distributed/actor.py index 69347fbd..abc5a2eb 100644 --- a/genrl/distributed/actor.py +++ b/genrl/distributed/actor.py @@ -12,7 +12,7 @@ def __init__( experience_server_name, learner_name, agent, - run_actor, + collect_experience, rank=None, ): super(ActorNode, self).__init__(name, master, rank) @@ -23,7 +23,7 @@ def __init__( experience_server_name=experience_server_name, learner_name=learner_name, agent=agent, - run_actor=run_actor, + collect_experience=collect_experience, ), ) self.start_proc() @@ -37,7 +37,7 @@ def act( experience_server_name, learner_name, agent, - run_actor, + collect_experience, rpc_backend, **kwargs, ): @@ -48,6 +48,5 @@ def act( experience_server = get_proxy(experience_server_name) learner = get_proxy(learner_name) print(f"{name}: Begining experience collection") - run_actor(agent, parameter_server, experience_server, learner) - + collect_experience(agent, parameter_server, experience_server, learner) rpc.shutdown() From 18536a2306b2c0ecaeb48f072b818e6c7e91c073 Mon Sep 17 00:00:00 2001 From: Atharv Sonwane Date: Thu, 29 Oct 2020 10:42:10 +0000 Subject: [PATCH 22/27] shifted examples --- examples/{ => distributed}/offpolicy_distributed_primary.py | 0 examples/{ => distributed}/offpolicy_distributed_secondary.py | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename examples/{ => distributed}/offpolicy_distributed_primary.py (100%) rename examples/{ => distributed}/offpolicy_distributed_secondary.py (100%) diff --git a/examples/offpolicy_distributed_primary.py b/examples/distributed/offpolicy_distributed_primary.py similarity index 100% rename from examples/offpolicy_distributed_primary.py rename to examples/distributed/offpolicy_distributed_primary.py diff --git a/examples/offpolicy_distributed_secondary.py b/examples/distributed/offpolicy_distributed_secondary.py similarity index 100% rename from examples/offpolicy_distributed_secondary.py rename to examples/distributed/offpolicy_distributed_secondary.py From 8f859d6e37f030bdb41637a597b2b753402c4fba Mon Sep 17 00:00:00 2001 From: Atharv Sonwane Date: Thu, 29 Oct 2020 10:42:48 +0000 Subject: [PATCH 23/27] shifted logger to base class --- examples/distributed/offpolicy_distributed_primary.py | 1 - genrl/trainers/distributed.py | 5 +++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/distributed/offpolicy_distributed_primary.py b/examples/distributed/offpolicy_distributed_primary.py index f435721c..79cccacd 100644 --- a/examples/distributed/offpolicy_distributed_primary.py +++ b/examples/distributed/offpolicy_distributed_primary.py @@ -47,7 +47,6 @@ def __init__(self, agent, train_steps, batch_size, init_buffer_size, log_interva self.train_steps = train_steps self.batch_size = batch_size self.init_buffer_size = init_buffer_size - self.logger = Logger(formats=["stdout"]) self.log_interval = log_interval def train(self, parameter_server, experience_server): diff --git a/genrl/trainers/distributed.py b/genrl/trainers/distributed.py index a93b7498..550d48fc 100644 --- a/genrl/trainers/distributed.py +++ b/genrl/trainers/distributed.py @@ -1,6 +1,5 @@ -from genrl.trainers import Trainer -import numpy 
as np from genrl.utils import safe_mean +from genrl.utils import Logger class DistributedTrainer: @@ -8,6 +7,8 @@ def __init__(self, agent): self.agent = agent self.env = self.agent.env self._completed_training_flag = False + self.logger = Logger(formats=["stdout"]) + def train(self, parameter_server, experience_server): raise NotImplementedError From 555e290ed9892cf3dd8037cfeefaa2907e408724 Mon Sep 17 00:00:00 2001 From: Atharv Sonwane Date: Thu, 29 Oct 2020 10:43:02 +0000 Subject: [PATCH 24/27] added on policy example --- examples/deep2.py | 9 + .../onpolicy_distributed_primary.py | 231 ++++++++++++++++++ 2 files changed, 240 insertions(+) create mode 100644 examples/deep2.py create mode 100644 examples/distributed/onpolicy_distributed_primary.py diff --git a/examples/deep2.py b/examples/deep2.py new file mode 100644 index 00000000..836fb133 --- /dev/null +++ b/examples/deep2.py @@ -0,0 +1,9 @@ +from genrl.agents import DDPG +from genrl.environments import VectorEnv +from genrl.trainers import OffPolicyTrainer + +env = VectorEnv("Pendulum-v0", n_envs=1) +agent = DDPG("mlp", env) +trainer = OffPolicyTrainer(agent, env, max_timesteps=20000) +trainer.train() +trainer.evaluate() diff --git a/examples/distributed/onpolicy_distributed_primary.py b/examples/distributed/onpolicy_distributed_primary.py new file mode 100644 index 00000000..99f80112 --- /dev/null +++ b/examples/distributed/onpolicy_distributed_primary.py @@ -0,0 +1,231 @@ +from genrl.distributed import ( + Master, + ExperienceServer, + ParameterServer, + ActorNode, + LearnerNode, +) +from genrl.core.policies import MlpPolicy +from genrl.core.values import MlpValue +from genrl.trainers import DistributedTrainer +import gym +import torch.distributed.rpc as rpc +import torch +from genrl.utils import get_env_properties + +N_ACTORS = 1 +BUFFER_SIZE = 20 +MAX_ENV_STEPS = 500 +TRAIN_STEPS = 1000 + + +def get_advantages_returns(rewards, dones, values, gamma=0.99, gae_lambda=1): + buffer_size = len(rewards) + advantages = torch.zeros_like(values) + last_value = values.flatten() + last_gae_lam = 0 + for step in reversed(range(buffer_size)): + if step == buffer_size - 1: + next_non_terminal = 1.0 - dones + next_value = last_value + else: + next_non_terminal = 1.0 - dones[step + 1] + next_value = values[step + 1] + delta = ( + rewards[step] + + gamma * next_value * next_non_terminal + - values[step] + ) + last_gae_lam = ( + delta + + gamma * gae_lambda * next_non_terminal * last_gae_lam + ) + advantages[step] = last_gae_lam + returns = advantages + values + return advantages, returns + + +def unroll_trajs(trajectories): + size = sum([len(traj) for traj in trajectories]) + obs = torch.zeros(size, *trajectories[0].states[0].shape) + actions = torch.zeros(size, *trajectories[0].actions[0].shape) + rewards = torch.zeros(size, *trajectories[0].rewards[0].shape) + dones = torch.zeros(size, *trajectories[0].dones[0].shape) + + i = 0 + for traj in trajectories: + for j in range(len(traj)): + obs[i] = traj.states[j] + actions[i] = traj.actions[j] + rewards[i] = traj.rewards[j] + dones[i] = traj.dones[j] + + return obs, actions, rewards, dones + + +class A2CActor: + def __init__(self, env, policy, value, policy_optim, value_optim, grad_norm_limit=0.5): + self.env = env + self.policy = policy + self.value = value + self.policy_optim = policy_optim + self.value_optim = value_optim + self.grad_norm_limit = grad_norm_limit + + def select_action(self, obs: torch.Tensor) -> tuple: + action_probs = self.policy(obs) + distribution = 
torch.distributions.Categorical(probs=action_probs) + action = distribution.sample() + return action, {"log_probs": distribution.log_prob(action)} + + def update_params(self, trajectories): + obs, actions, rewards, dones = unroll_trajs(trajectories) + values = self.value(obs) + dist = torch.distributions.Categorical(self.policy(obs)) + log_probs = dist.log_prob(actions) + entropy = dist.entropy() + advantages, returns = get_advantages_returns(obs, rewards, dones, values) + + policy_loss = -torch.mean(advantages * log_probs) - torch.mean(entropy) + value_loss = torch.nn.fucntion.mse_loss(returns, values) + + self.policy_optim.zero_grad() + policy_loss.backward() + torch.nn.utils.clip_grad_norm_(self.policy.parameters(), self.grad_norm_limit) + self.policy_optim.step() + + self.value_optim.zero_grad() + value_loss.backward() + torch.nn.utils.clip_grad_norm_(self.values.parameters(), self.grad_norm_limit) + self.value_optim.step() + + def get_weights(self): + return { + "policy": self.policy.state_dict(), + "value": self.value.state_dict() + } + + def load_weights(self, weights): + self.policy.load_state_dict(weights["policy"]) + self.value.load_state_dict(weights["value"]) + +class Trajectory(): + def __init__(self): + self.states = [] + self.actions = [] + self.rewards = [] + self.dones = [] + self.__len = 0 + + def add(self, state, action, reward, done): + self.states.append(state) + self.actions.append(action) + self.rewards.append(reward) + self.dones.append(done) + self.__len += 1 + + def __len__(self): + return self.__len + +class TrajBuffer(): + def __init__(self, size): + if size <= 0: + raise ValueError("Size of buffer must be larger than 0") + self._size = size + self._memory = [] + self._full = False + + def is_full(self): + return self._full + + def push(self, traj): + if not self.is_full(): + self._memory.append(traj) + if len(self._memory) >= self._size: + self._full = True + + def get(self, batch_size=None): + if batch_size is None: + return self._memory + + +def collect_experience(agent, parameter_server, experience_server, learner): + current_train_step = -1 + while not learner.is_completed(): + traj = Trajectory() + while not learner.current_train_step > current_train_step: + pass + agent.load_weights(parameter_server.get_weights()) + while not experience_server.is_full(): + obs = agent.env.reset() + done = False + for _ in range(MAX_ENV_STEPS): + action = agent.select_action(obs) + next_obs, reward, done, _ = agent.env.step(action) + traj.add(obs, action, reward, done) + obs = next_obs + if done: + break + experience_server.push(traj) + + +class MyTrainer(DistributedTrainer): + def __init__(self, agent, train_steps, log_interval=200): + super(MyTrainer, self).__init__(agent) + self.train_steps = train_steps + self.log_interval = log_interval + self._weights_available = True + self._current_train_step = 0 + + @property + def current_train_step(self): + return self._current_train_step + + def train(self, parameter_server, experience_server): + for self._current_train_step in range(self.train_steps): + if experience_server.is_full(): + self._weights_available = False + trajectories = experience_server.get() + if trajectories is None: + continue + self.agent.update_params(1, trajectories) + parameter_server.store_weights(self.agent.get_weights()) + self._weights_available = True + if self._current_train_step % self.log_interval == 0: + self.evaluate(self._current_train_step) + + +master = Master( + world_size=5, + address="localhost", + port=29502, + proc_start_method="fork", 
+ rpc_backend=rpc.BackendType.TENSORPIPE, +) + +env = gym.make("Pendulum-v0") +state_dim, action_dim, discrete, action_lim = get_env_properties(env, "mlp") +policy = MlpPolicy(state_dim, action_dim, (32, 32), discrete) +value = MlpValue(state_dim, action_dim, "V", (32, 32)) +policy_optim = torch.optim.Adam(policy.parameters(), lr=1e-3) +value_optim = torch.optim.Adam(value.parameters(), lr=1e-3) +agent = A2CActor(env, policy, value, policy_optim, value_optim) +buffer = TrajBuffer(BUFFER_SIZE) + +parameter_server = ParameterServer("param-0", master, agent.get_weights(), rank=1) +experience_server = ExperienceServer("experience-0", master, buffer, rank=2) +trainer = MyTrainer(agent, TRAIN_STEPS) +learner = LearnerNode("learner-0", master, "param-0", "experience-0", trainer, rank=3) +actors = [ + ActorNode( + name=f"actor-{i}", + master=master, + parameter_server_name="param-0", + experience_server_name="experience-0", + learner_name="learner-0", + agent=agent, + collect_experience=collect_experience, + rank=i + 4, + ) + for i in range(N_ACTORS) +] From 59e960cedff64d31406cef7346685221be242bb3 Mon Sep 17 00:00:00 2001 From: Atharv Sonwane Date: Thu, 29 Oct 2020 10:43:22 +0000 Subject: [PATCH 25/27] removed temp example --- examples/deep2.py | 9 --------- 1 file changed, 9 deletions(-) delete mode 100644 examples/deep2.py diff --git a/examples/deep2.py b/examples/deep2.py deleted file mode 100644 index 836fb133..00000000 --- a/examples/deep2.py +++ /dev/null @@ -1,9 +0,0 @@ -from genrl.agents import DDPG -from genrl.environments import VectorEnv -from genrl.trainers import OffPolicyTrainer - -env = VectorEnv("Pendulum-v0", n_envs=1) -agent = DDPG("mlp", env) -trainer = OffPolicyTrainer(agent, env, max_timesteps=20000) -trainer.train() -trainer.evaluate() From 8d5a8b6e827afea1e5d446aadcfdf5631f7bc9de Mon Sep 17 00:00:00 2001 From: Atharv Sonwane Date: Thu, 29 Oct 2020 17:50:40 +0000 Subject: [PATCH 26/27] got on policy distributed example to work --- .../onpolicy_distributed_primary.py | 90 ++++++++++--------- genrl/trainers/distributed.py | 2 +- 2 files changed, 51 insertions(+), 41 deletions(-) diff --git a/examples/distributed/onpolicy_distributed_primary.py b/examples/distributed/onpolicy_distributed_primary.py index 99f80112..85580fa5 100644 --- a/examples/distributed/onpolicy_distributed_primary.py +++ b/examples/distributed/onpolicy_distributed_primary.py @@ -12,22 +12,24 @@ import torch.distributed.rpc as rpc import torch from genrl.utils import get_env_properties +import torch.nn.functional as F +import copy +import time N_ACTORS = 1 -BUFFER_SIZE = 20 +BUFFER_SIZE = 5 MAX_ENV_STEPS = 500 -TRAIN_STEPS = 1000 +TRAIN_STEPS = 50 def get_advantages_returns(rewards, dones, values, gamma=0.99, gae_lambda=1): buffer_size = len(rewards) - advantages = torch.zeros_like(values) - last_value = values.flatten() + advantages = torch.zeros_like(rewards) last_gae_lam = 0 for step in reversed(range(buffer_size)): if step == buffer_size - 1: - next_non_terminal = 1.0 - dones - next_value = last_value + next_non_terminal = 1.0 - dones[-1] + next_value = values[-1] else: next_non_terminal = 1.0 - dones[step + 1] next_value = values[step + 1] @@ -42,28 +44,28 @@ def get_advantages_returns(rewards, dones, values, gamma=0.99, gae_lambda=1): ) advantages[step] = last_gae_lam returns = advantages + values - return advantages, returns + return advantages.detach(), returns.detach() def unroll_trajs(trajectories): size = sum([len(traj) for traj in trajectories]) obs = torch.zeros(size, 
*trajectories[0].states[0].shape) - actions = torch.zeros(size, *trajectories[0].actions[0].shape) - rewards = torch.zeros(size, *trajectories[0].rewards[0].shape) - dones = torch.zeros(size, *trajectories[0].dones[0].shape) + actions = torch.zeros(size) + rewards = torch.zeros(size) + dones = torch.zeros(size) i = 0 for traj in trajectories: for j in range(len(traj)): - obs[i] = traj.states[j] - actions[i] = traj.actions[j] - rewards[i] = traj.rewards[j] - dones[i] = traj.dones[j] + obs[i] = torch.tensor(traj.states[j]) + actions[i] = torch.tensor(traj.actions[j]) + rewards[i] = torch.tensor(traj.rewards[j]) + dones[i] = torch.tensor(traj.dones[j]) return obs, actions, rewards, dones -class A2CActor: +class A2C: def __init__(self, env, policy, value, policy_optim, value_optim, grad_norm_limit=0.5): self.env = env self.policy = policy @@ -72,22 +74,22 @@ def __init__(self, env, policy, value, policy_optim, value_optim, grad_norm_limi self.value_optim = value_optim self.grad_norm_limit = grad_norm_limit - def select_action(self, obs: torch.Tensor) -> tuple: - action_probs = self.policy(obs) - distribution = torch.distributions.Categorical(probs=action_probs) - action = distribution.sample() - return action, {"log_probs": distribution.log_prob(action)} + def select_action(self, obs: torch.Tensor, deterministic: bool = False): + logits = self.policy(torch.tensor(obs, dtype=torch.float)) + distribution = torch.distributions.Categorical(logits=logits) + action = torch.argmax(logits) if deterministic else distribution.sample() + return action.item() def update_params(self, trajectories): obs, actions, rewards, dones = unroll_trajs(trajectories) - values = self.value(obs) + values = self.value(obs).view(-1) dist = torch.distributions.Categorical(self.policy(obs)) log_probs = dist.log_prob(actions) entropy = dist.entropy() - advantages, returns = get_advantages_returns(obs, rewards, dones, values) + advantages, returns = get_advantages_returns(rewards, dones, values) policy_loss = -torch.mean(advantages * log_probs) - torch.mean(entropy) - value_loss = torch.nn.fucntion.mse_loss(returns, values) + value_loss = F.mse_loss(returns, values) self.policy_optim.zero_grad() policy_loss.backward() @@ -96,7 +98,7 @@ def update_params(self, trajectories): self.value_optim.zero_grad() value_loss.backward() - torch.nn.utils.clip_grad_norm_(self.values.parameters(), self.grad_norm_limit) + torch.nn.utils.clip_grad_norm_(self.value.parameters(), self.grad_norm_limit) self.value_optim.step() def get_weights(self): @@ -144,17 +146,21 @@ def push(self, traj): if len(self._memory) >= self._size: self._full = True - def get(self, batch_size=None): - if batch_size is None: - return self._memory - + def get(self, clear=True): + out = copy.deepcopy(self._memory) + if clear: + self._memory = [] + self._full = False + return out def collect_experience(agent, parameter_server, experience_server, learner): - current_train_step = -1 + current_step = -1 while not learner.is_completed(): + if not learner.current_train_step() > current_step: + time.sleep(0.5) + continue + current_step = learner.current_train_step() traj = Trajectory() - while not learner.current_train_step > current_train_step: - pass agent.load_weights(parameter_server.get_weights()) while not experience_server.is_full(): obs = agent.env.reset() @@ -165,51 +171,55 @@ def collect_experience(agent, parameter_server, experience_server, learner): traj.add(obs, action, reward, done) obs = next_obs if done: - break + break experience_server.push(traj) + 
print("pushed a traj") class MyTrainer(DistributedTrainer): - def __init__(self, agent, train_steps, log_interval=200): + def __init__(self, agent, train_steps, log_interval=1): super(MyTrainer, self).__init__(agent) self.train_steps = train_steps self.log_interval = log_interval self._weights_available = True self._current_train_step = 0 - @property def current_train_step(self): return self._current_train_step def train(self, parameter_server, experience_server): - for self._current_train_step in range(self.train_steps): + self._current_train_step = 0 + while True: if experience_server.is_full(): self._weights_available = False trajectories = experience_server.get() if trajectories is None: continue - self.agent.update_params(1, trajectories) + self.agent.update_params(trajectories) parameter_server.store_weights(self.agent.get_weights()) self._weights_available = True if self._current_train_step % self.log_interval == 0: self.evaluate(self._current_train_step) + self._current_train_step += 1 + if self._current_train_step >= self.train_steps: + break master = Master( - world_size=5, + world_size=N_ACTORS+4, address="localhost", - port=29502, + port=29500, proc_start_method="fork", rpc_backend=rpc.BackendType.TENSORPIPE, ) -env = gym.make("Pendulum-v0") +env = gym.make("CartPole-v0") state_dim, action_dim, discrete, action_lim = get_env_properties(env, "mlp") policy = MlpPolicy(state_dim, action_dim, (32, 32), discrete) value = MlpValue(state_dim, action_dim, "V", (32, 32)) policy_optim = torch.optim.Adam(policy.parameters(), lr=1e-3) value_optim = torch.optim.Adam(value.parameters(), lr=1e-3) -agent = A2CActor(env, policy, value, policy_optim, value_optim) +agent = A2C(env, policy, value, policy_optim, value_optim) buffer = TrajBuffer(BUFFER_SIZE) parameter_server = ParameterServer("param-0", master, agent.get_weights(), rank=1) diff --git a/genrl/trainers/distributed.py b/genrl/trainers/distributed.py index 550d48fc..33c50b2d 100644 --- a/genrl/trainers/distributed.py +++ b/genrl/trainers/distributed.py @@ -41,7 +41,7 @@ def evaluate(self, timestep, render: bool = False) -> None: self.logger.write( { "timestep": timestep, - **self.agent.get_logging_params(), + # **self.agent.get_logging_params(), "Episode Reward": safe_mean(episode_rewards), }, "timestep", From 8030b2a77ebb607532e38eddf5d97367a79cfda2 Mon Sep 17 00:00:00 2001 From: Atharv Sonwane Date: Thu, 29 Oct 2020 17:51:14 +0000 Subject: [PATCH 27/27] formatting --- .../offpolicy_distributed_primary.py | 4 +- .../onpolicy_distributed_primary.py | 37 ++++++++----------- 2 files changed, 19 insertions(+), 22 deletions(-) diff --git a/examples/distributed/offpolicy_distributed_primary.py b/examples/distributed/offpolicy_distributed_primary.py index 79cccacd..c6b726a6 100644 --- a/examples/distributed/offpolicy_distributed_primary.py +++ b/examples/distributed/offpolicy_distributed_primary.py @@ -42,7 +42,9 @@ def collect_experience(agent, parameter_server, experience_server, learner): class MyTrainer(DistributedTrainer): - def __init__(self, agent, train_steps, batch_size, init_buffer_size, log_interval=200): + def __init__( + self, agent, train_steps, batch_size, init_buffer_size, log_interval=200 + ): super(MyTrainer, self).__init__(agent) self.train_steps = train_steps self.batch_size = batch_size diff --git a/examples/distributed/onpolicy_distributed_primary.py b/examples/distributed/onpolicy_distributed_primary.py index 85580fa5..9709f4f1 100644 --- a/examples/distributed/onpolicy_distributed_primary.py +++ 
b/examples/distributed/onpolicy_distributed_primary.py @@ -33,15 +33,8 @@ def get_advantages_returns(rewards, dones, values, gamma=0.99, gae_lambda=1): else: next_non_terminal = 1.0 - dones[step + 1] next_value = values[step + 1] - delta = ( - rewards[step] - + gamma * next_value * next_non_terminal - - values[step] - ) - last_gae_lam = ( - delta - + gamma * gae_lambda * next_non_terminal * last_gae_lam - ) + delta = rewards[step] + gamma * next_value * next_non_terminal - values[step] + last_gae_lam = delta + gamma * gae_lambda * next_non_terminal * last_gae_lam advantages[step] = last_gae_lam returns = advantages + values return advantages.detach(), returns.detach() @@ -66,7 +59,9 @@ def unroll_trajs(trajectories): class A2C: - def __init__(self, env, policy, value, policy_optim, value_optim, grad_norm_limit=0.5): + def __init__( + self, env, policy, value, policy_optim, value_optim, grad_norm_limit=0.5 + ): self.env = env self.policy = policy self.value = value @@ -102,16 +97,14 @@ def update_params(self, trajectories): self.value_optim.step() def get_weights(self): - return { - "policy": self.policy.state_dict(), - "value": self.value.state_dict() - } + return {"policy": self.policy.state_dict(), "value": self.value.state_dict()} def load_weights(self, weights): self.policy.load_state_dict(weights["policy"]) self.value.load_state_dict(weights["value"]) -class Trajectory(): + +class Trajectory: def __init__(self): self.states = [] self.actions = [] @@ -129,7 +122,8 @@ def add(self, state, action, reward, done): def __len__(self): return self.__len -class TrajBuffer(): + +class TrajBuffer: def __init__(self, size): if size <= 0: raise ValueError("Size of buffer must be larger than 0") @@ -139,20 +133,21 @@ def __init__(self, size): def is_full(self): return self._full - + def push(self, traj): if not self.is_full(): self._memory.append(traj) if len(self._memory) >= self._size: self._full = True - + def get(self, clear=True): - out = copy.deepcopy(self._memory) + out = copy.deepcopy(self._memory) if clear: self._memory = [] self._full = False return out + def collect_experience(agent, parameter_server, experience_server, learner): current_step = -1 while not learner.is_completed(): @@ -171,7 +166,7 @@ def collect_experience(agent, parameter_server, experience_server, learner): traj.add(obs, action, reward, done) obs = next_obs if done: - break + break experience_server.push(traj) print("pushed a traj") @@ -206,7 +201,7 @@ def train(self, parameter_server, experience_server): master = Master( - world_size=N_ACTORS+4, + world_size=N_ACTORS + 4, address="localhost", port=29500, proc_start_method="fork",