adding unit test for end-to-end example #669

Merged Jul 6, 2023 · 28 commits

Commits
fca3664
adding unit test for multi-gpu example
bbozkaya Apr 5, 2023
d24915a
added test for notebook 03
bbozkaya Apr 12, 2023
def2a60
fixed formatting
bbozkaya Apr 12, 2023
7df2bca
Merge branch 'main' into unittest_endtoend_multi
rnyak Apr 12, 2023
426e5a3
Merge branch 'main' into unittest_endtoend_multi
rnyak Apr 13, 2023
eaa0592
Merge branch 'main' into unittest_endtoend_multi
edknv Apr 18, 2023
c3b2332
Merge branch 'main' into unittest_endtoend_multi
bbozkaya May 9, 2023
160f296
update
bbozkaya May 9, 2023
06aabc7
update
bbozkaya May 9, 2023
1577c58
Update 01-ETL-with-NVTabular.ipynb
bbozkaya May 10, 2023
fc9c5c3
Update 01-ETL-with-NVTabular.ipynb
bbozkaya May 10, 2023
17bae4b
Merge branch 'main' into unittest_endtoend_multi
rnyak May 15, 2023
ec98739
reduce num_rows
rnyak May 18, 2023
213c5a1
Update test_end_to_end_session_based.py
bbozkaya May 19, 2023
236f4ff
Update 01-ETL-with-NVTabular.ipynb
bbozkaya May 19, 2023
df54cf1
Merge branch 'main' into unittest_endtoend_multi
bbozkaya May 19, 2023
5e46247
updated test script and notebook
bbozkaya May 22, 2023
24b0070
updated file
bbozkaya May 22, 2023
bd81dcd
removed nb3 test due to multi-gpu freezing issue
bbozkaya May 22, 2023
18992d5
Merge branch 'main' into unittest_endtoend_multi
rnyak May 24, 2023
7cc9bc0
Merge branch 'main' into unittest_endtoend_multi
rnyak Jun 1, 2023
c02a707
revised notebooks, added back nb3 test
bbozkaya Jun 28, 2023
f09e08f
Merge branch 'main' into unittest_endtoend_multi
bbozkaya Jun 28, 2023
253dfb7
fixed test file with black
bbozkaya Jun 28, 2023
7afc670
Merge branch 'unittest_endtoend_multi' of https://github.com/NVIDIA-M…
bbozkaya Jun 28, 2023
405ed68
update test py
bbozkaya Jul 5, 2023
4e0efcd
update test py
bbozkaya Jul 5, 2023
5cc0c06
Use `python -m torch.distributed.run` instead of `torchrun`
oliverholworthy Jul 6, 2023
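
The last commit replaces the `torchrun` entry point with `python -m torch.distributed.run`. A minimal sketch of what that swap amounts to, assuming a test launches a training script as a subprocess (the script name and process count are placeholders, not taken from this PR):

```python
# Sketch of the launcher swap in commit 5cc0c06: `torchrun` is the console
# script for the `torch.distributed.run` module, so invoking the module with
# `python -m` pins the launcher to the current interpreter. The script name
# and process count are placeholders.
import subprocess
import sys

subprocess.run(
    [
        sys.executable, "-m", "torch.distributed.run",
        "--nproc_per_node", "2",  # placeholder GPU count
        "train.py",               # placeholder training script
    ],
    check=True,
)
```

Using `sys.executable` avoids picking up a `torchrun` shim from a different environment, which is the usual motivation for this change.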
160 changes: 123 additions & 37 deletions examples/end-to-end-session-based/01-ETL-with-NVTabular.ipynb
@@ -61,6 +61,8 @@
"\n",
"The dataset is available on [Kaggle](https://www.kaggle.com/chadgostopp/recsys-challenge-2015). You need to download it and copy to the `DATA_FOLDER` path. Note that we are only using the `yoochoose-clicks.dat` file.\n",
"\n",
"Alternatively, you can generate a synthetic dataset with the same columns and dtypes as the `YOOCHOOSE` dataset and a default date range of 5 days. If the environment variable `USE_SYNTHETIC` is set to `True`, the code below will execute the function `generate_synthetic_data` and the rest of the notebook will run on a synthetic dataset.\n",
"\n",
"First, let's start by importing several libraries:"
]
},
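
A minimal sketch of exercising this switch by executing the notebook headlessly (the `nbconvert` invocation is an assumption, not part of this PR; the variable names match the notebook):

```python
# Run the notebook against synthetic data by setting the environment
# variables it reads. The nbconvert invocation is an assumption.
import os
import subprocess

env = dict(os.environ, USE_SYNTHETIC="True", DATA_FOLDER="/tmp/yoochoose")
subprocess.run(
    ["jupyter", "nbconvert", "--to", "notebook", "--execute",
     "01-ETL-with-NVTabular.ipynb"],
    env=env,
    check=True,
)
```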
@@ -75,17 +77,18 @@
"output_type": "stream",
"text": [
"/usr/local/lib/python3.8/dist-packages/merlin/dtypes/mappings/tf.py:52: UserWarning: Tensorflow dtype mappings did not load successfully due to an error: No module named 'tensorflow'\n",
" warn(f\"Tensorflow dtype mappings did not load successfully due to an error: {exc.msg}\")\n",
"/usr/local/lib/python3.8/dist-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
" warn(f\"Tensorflow dtype mappings did not load successfully due to an error: {exc.msg}\")\n"
]
}
],
"source": [
"import os\n",
"import glob\n",
"import numpy as np\n",
"import pandas as pd\n",
"import gc\n",
"import calendar\n",
"import datetime\n",
"\n",
"import cudf\n",
"import cupy\n",
@@ -128,12 +131,14 @@
"metadata": {},
"outputs": [],
"source": [
"DATA_FOLDER = \"/workspace/data/\"\n",
"DATA_FOLDER = os.environ.get(\"DATA_FOLDER\", \"/workspace/data\")\n",
"FILENAME_PATTERN = 'yoochoose-clicks.dat'\n",
"DATA_PATH = os.path.join(DATA_FOLDER, FILENAME_PATTERN)\n",
"\n",
"OUTPUT_FOLDER = \"./yoochoose_transformed\"\n",
"OVERWRITE = False"
"OVERWRITE = False\n",
"\n",
"USE_SYNTHETIC = os.environ.get(\"USE_SYNTHETIC\", False)"
]
},
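
One caveat with the cell above (an observation, not something this PR changes): `os.environ.get` returns a string, so any non-empty value, including the literal `"False"`, leaves `USE_SYNTHETIC` truthy. A stricter parse would be:

```python
# Stricter alternative (not what the PR does): os.environ.get returns
# strings, and the string "False" is truthy, so parse the value explicitly.
USE_SYNTHETIC = os.environ.get("USE_SYNTHETIC", "false").lower() in ("1", "true", "yes")
```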
{
@@ -144,16 +149,89 @@
"## Load and clean raw data"
]
},
{
"cell_type": "markdown",
"id": "3fba8546-668c-4743-960e-ea2aef99ef24",
"metadata": {},
"source": [
"Execute the cell below if you would like to work with synthetic data. Otherwise you can skip and continue with the next cell."
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "07d14289-c783-45f0-86e8-e5c1001bfd76",
"metadata": {},
"outputs": [],
"source": [
"def generate_synthetic_data(\n",
" start_date: datetime.date, end_date: datetime.date, rows_per_day: int = 10000\n",
") -> pd.DataFrame:\n",
" assert end_date > start_date, \"end_date must be later than start_date\"\n",
"\n",
" number_of_days = (end_date - start_date).days\n",
" total_number_of_rows = number_of_days * rows_per_day\n",
"\n",
" # Generate a long-tail distribution of item interactions. This simulates that some items are\n",
" # more popular than others.\n",
" long_tailed_item_distribution = np.clip(\n",
" np.random.lognormal(3.0, 1.0, total_number_of_rows).astype(np.int64), 1, 50000\n",
" )\n",
"\n",
" # generate random item interaction features\n",
" df = pd.DataFrame(\n",
" {\n",
" \"session_id\": np.random.randint(70000, 80000, total_number_of_rows),\n",
" \"item_id\": long_tailed_item_distribution,\n",
" },\n",
" )\n",
"\n",
" # generate category mapping for each item-id\n",
" df[\"category\"] = pd.cut(df[\"item_id\"], bins=334, labels=np.arange(1, 335)).astype(\n",
" np.int64\n",
" )\n",
"\n",
" max_session_length = 60 * 60 # 1 hour\n",
"\n",
" def add_timestamp_to_session(session: pd.DataFrame):\n",
" random_start_date_and_time = calendar.timegm(\n",
" (\n",
" start_date\n",
" # Add day offset from start_date\n",
" + datetime.timedelta(days=np.random.randint(0, number_of_days))\n",
" # Add time offset within the random day\n",
" + datetime.timedelta(seconds=np.random.randint(0, 86_400))\n",
" ).timetuple()\n",
" )\n",
" session[\"timestamp\"] = random_start_date_and_time + np.clip(\n",
" np.random.lognormal(3.0, 1.0, len(session)).astype(np.int64),\n",
" 0,\n",
" max_session_length,\n",
" )\n",
" return session\n",
"\n",
" df = df.groupby(\"session_id\").apply(add_timestamp_to_session).reset_index()\n",
"\n",
" return df"
]
},
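
For reference, a standalone sketch of calling the new helper (the dates and row count are illustrative; it assumes the notebook's imports of `numpy`, `pandas`, `calendar`, and `datetime` are in scope):

```python
# Illustrative call to generate_synthetic_data as defined above; the dates
# and row count are arbitrary.
import datetime

df = generate_synthetic_data(
    datetime.date(2014, 4, 1),
    datetime.date(2014, 4, 5),
    rows_per_day=1_000,
)
print(df.shape)  # roughly 4 days * 1,000 rows
print(df.columns.tolist())
```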
{
"cell_type": "code",
"execution_count": 6,
"id": "f35dff52",
"metadata": {},
"outputs": [],
"source": [
"interactions_df = cudf.read_csv(DATA_PATH, sep=',', \n",
" names=['session_id','timestamp', 'item_id', 'category'], \n",
" dtype=['int', 'datetime64[s]', 'int', 'int'])"
"if USE_SYNTHETIC:\n",
" START_DATE = os.environ.get(\"START_DATE\", \"2014/4/1\")\n",
" END_DATE = os.environ.get(\"END_DATE\", \"2014/4/5\")\n",
" interactions_df = generate_synthetic_data(datetime.datetime.strptime(START_DATE, '%Y/%m/%d'),\n",
" datetime.datetime.strptime(END_DATE, '%Y/%m/%d'))\n",
" interactions_df = cudf.from_pandas(interactions_df)\n",
"else:\n",
" interactions_df = cudf.read_csv(DATA_PATH, sep=',', \n",
" names=['session_id','timestamp', 'item_id', 'category'], \n",
" dtype=['int', 'datetime64[s]', 'int', 'int'])"
]
},
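
A quick sanity check of the date handling in the synthetic branch (illustrative; it uses the same `%Y/%m/%d` format string as the cell above):

```python
# Illustrative: the synthetic branch parses START_DATE/END_DATE with
# '%Y/%m/%d', and generate_synthetic_data asserts end_date > start_date.
import datetime

start = datetime.datetime.strptime("2014/4/1", "%Y/%m/%d")
end = datetime.datetime.strptime("2014/4/5", "%Y/%m/%d")
assert end > start
print((end - start).days)  # 4 days of synthetic interactions
```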
{
@@ -166,7 +244,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 7,
"id": "22c2df72",
"metadata": {},
"outputs": [
@@ -181,13 +259,16 @@
],
"source": [
"print(\"Count with in-session repeated interactions: {}\".format(len(interactions_df)))\n",
"\n",
"# Sorts the dataframe by session and timestamp, to remove consecutive repetitions\n",
"interactions_df.timestamp = interactions_df.timestamp.astype(int)\n",
"interactions_df = interactions_df.sort_values(['session_id', 'timestamp'])\n",
"past_ids = interactions_df['item_id'].shift(1).fillna()\n",
"session_past_ids = interactions_df['session_id'].shift(1).fillna()\n",
"\n",
"# Keeping only no consecutive repeated in session interactions\n",
"interactions_df = interactions_df[~((interactions_df['session_id'] == session_past_ids) & (interactions_df['item_id'] == past_ids))]\n",
"\n",
"print(\"Count after removed in-session repeated interactions: {}\".format(len(interactions_df)))"
]
},
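
The filter above is the standard shift-and-compare pattern for dropping consecutive duplicates; a toy pandas illustration of the same logic (the notebook uses cuDF, which mirrors this part of the pandas API):

```python
# Toy pandas version of the shift-and-compare dedup used above.
import pandas as pd

df = pd.DataFrame(
    {"session_id": [1, 1, 1, 2, 2], "item_id": [10, 10, 11, 11, 11]}
)
prev_session = df["session_id"].shift(1)
prev_item = df["item_id"].shift(1)
# Drop a row only when both the session and the item repeat the previous row.
deduped = df[~((df["session_id"] == prev_session) & (df["item_id"] == prev_item))]
print(deduped.index.tolist())  # [0, 2, 3]
```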
@@ -201,7 +282,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 8,
"id": "66a1bd13",
"metadata": {},
"outputs": [
@@ -234,17 +315,19 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 9,
"id": "a0f908a1",
"metadata": {},
"outputs": [],
"source": [
"if os.path.isdir(DATA_FOLDER) == False:\n",
" os.mkdir(DATA_FOLDER)\n",
"interactions_merged_df.to_parquet(os.path.join(DATA_FOLDER, 'interactions_merged_df.parquet'))"
]
},
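
A side note on the directory check above (an alternative, not a change made in this PR): `os.makedirs` with `exist_ok=True` is the idiomatic one-liner and also creates intermediate directories:

```python
# Idiomatic equivalent of the isdir/mkdir pair above (not part of this PR).
import os

os.makedirs(DATA_FOLDER, exist_ok=True)
```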
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 10,
"id": "909f87c5-bff5-48c8-b714-cc556a4bc64d",
"metadata": {
"tags": []
@@ -265,17 +348,17 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 11,
"id": "04a3b5b7",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0"
"517"
]
},
"execution_count": 10,
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
@@ -330,7 +413,7 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 13,
"id": "86f80068",
"metadata": {},
"outputs": [],
@@ -425,7 +508,7 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 14,
"id": "10b5c96c",
"metadata": {},
"outputs": [],
@@ -447,7 +530,6 @@
"# Truncate sequence features to first interacted 20 items \n",
"SESSIONS_MAX_LENGTH = 20 \n",
"\n",
"\n",
"item_feat = groupby_features['item_id-list'] >> nvt.ops.TagAsItemID()\n",
"cont_feats = groupby_features['et_dayofweek_sin-list', 'product_recency_days_log_norm-list'] >> nvt.ops.AddMetadata(tags=[Tags.CONTINUOUS])\n",
"\n",
@@ -491,7 +573,7 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 15,
"id": "45803886",
"metadata": {},
"outputs": [],
@@ -513,7 +595,7 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": 16,
"id": "4c10efb5-89c5-4458-a634-475eb459a47c",
"metadata": {
"tags": []
@@ -600,7 +682,7 @@
" <tr>\n",
" <th>2</th>\n",
" <td>item_id-list</td>\n",
" <td>(Tags.CATEGORICAL, Tags.ITEM, Tags.ID, Tags.LIST)</td>\n",
" <td>(Tags.CATEGORICAL, Tags.ID, Tags.LIST, Tags.ITEM)</td>\n",
" <td>DType(name='int64', element_type=&lt;ElementType....</td>\n",
" <td>True</td>\n",
" <td>True</td>\n",
@@ -697,10 +779,10 @@
"</div>"
],
"text/plain": [
"[{'name': 'session_id', 'tags': {<Tags.CATEGORICAL: 'categorical'>}, 'properties': {}, 'dtype': DType(name='int64', element_type=<ElementType.Int: 'int'>, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None),))), 'is_list': False, 'is_ragged': False}, {'name': 'item_id-count', 'tags': {<Tags.CATEGORICAL: 'categorical'>}, 'properties': {'num_buckets': None, 'freq_threshold': 0, 'max_size': 0, 'cat_path': './/categories/unique.item_id.parquet', 'domain': {'min': 0, 'max': 52741, 'name': 'item_id'}, 'embedding_sizes': {'cardinality': 52742, 'dimension': 512}}, 'dtype': DType(name='int32', element_type=<ElementType.Int: 'int'>, element_size=32, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None),))), 'is_list': False, 'is_ragged': False}, {'name': 'item_id-list', 'tags': {<Tags.CATEGORICAL: 'categorical'>, <Tags.ITEM: 'item'>, <Tags.ID: 'id'>, <Tags.LIST: 'list'>}, 'properties': {'num_buckets': None, 'freq_threshold': 0, 'max_size': 0, 'cat_path': './/categories/unique.item_id.parquet', 'domain': {'min': 0, 'max': 52741, 'name': 'item_id'}, 'embedding_sizes': {'cardinality': 52742, 'dimension': 512}, 'value_count': {'min': 0, 'max': 20}}, 'dtype': DType(name='int64', element_type=<ElementType.Int: 'int'>, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=0, max=20)))), 'is_list': True, 'is_ragged': True}, {'name': 'et_dayofweek_sin-list', 'tags': {<Tags.CONTINUOUS: 'continuous'>, <Tags.LIST: 'list'>}, 'properties': {'value_count': {'min': 0, 'max': 20}}, 'dtype': DType(name='float64', element_type=<ElementType.Float: 'float'>, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=0, max=20)))), 'is_list': True, 'is_ragged': True}, {'name': 'product_recency_days_log_norm-list', 'tags': {<Tags.CONTINUOUS: 'continuous'>, <Tags.LIST: 'list'>}, 'properties': {'value_count': {'min': 0, 'max': 20}}, 'dtype': DType(name='float32', element_type=<ElementType.Float: 'float'>, element_size=32, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=0, max=20)))), 'is_list': True, 'is_ragged': True}, {'name': 'category-list', 'tags': {<Tags.CATEGORICAL: 'categorical'>, <Tags.LIST: 'list'>}, 'properties': {'num_buckets': None, 'freq_threshold': 0, 'max_size': 0, 'cat_path': './/categories/unique.category.parquet', 'domain': {'min': 0, 'max': 336, 'name': 'category'}, 'embedding_sizes': {'cardinality': 337, 'dimension': 42}, 'value_count': {'min': 0, 'max': 20}}, 'dtype': DType(name='int64', element_type=<ElementType.Int: 'int'>, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=0, max=20)))), 'is_list': True, 'is_ragged': True}, {'name': 'day_index', 'tags': {<Tags.CATEGORICAL: 'categorical'>}, 'properties': {}, 'dtype': DType(name='int64', element_type=<ElementType.Int: 'int'>, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None),))), 'is_list': False, 'is_ragged': False}]"
"[{'name': 'session_id', 'tags': {<Tags.CATEGORICAL: 'categorical'>}, 'properties': {}, 'dtype': DType(name='int64', element_type=<ElementType.Int: 'int'>, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None),))), 'is_list': False, 'is_ragged': False}, {'name': 'item_id-count', 'tags': {<Tags.CATEGORICAL: 'categorical'>}, 'properties': {'num_buckets': None, 'freq_threshold': 0, 'max_size': 0, 'cat_path': './/categories/unique.item_id.parquet', 'domain': {'min': 0, 'max': 52741, 'name': 'item_id'}, 'embedding_sizes': {'cardinality': 52742, 'dimension': 512}}, 'dtype': DType(name='int32', element_type=<ElementType.Int: 'int'>, element_size=32, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None),))), 'is_list': False, 'is_ragged': False}, {'name': 'item_id-list', 'tags': {<Tags.CATEGORICAL: 'categorical'>, <Tags.ID: 'id'>, <Tags.LIST: 'list'>, <Tags.ITEM: 'item'>}, 'properties': {'num_buckets': None, 'freq_threshold': 0, 'max_size': 0, 'cat_path': './/categories/unique.item_id.parquet', 'domain': {'min': 0, 'max': 52741, 'name': 'item_id'}, 'embedding_sizes': {'cardinality': 52742, 'dimension': 512}, 'value_count': {'min': 0, 'max': 20}}, 'dtype': DType(name='int64', element_type=<ElementType.Int: 'int'>, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=0, max=20)))), 'is_list': True, 'is_ragged': True}, {'name': 'et_dayofweek_sin-list', 'tags': {<Tags.CONTINUOUS: 'continuous'>, <Tags.LIST: 'list'>}, 'properties': {'value_count': {'min': 0, 'max': 20}}, 'dtype': DType(name='float64', element_type=<ElementType.Float: 'float'>, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=0, max=20)))), 'is_list': True, 'is_ragged': True}, {'name': 'product_recency_days_log_norm-list', 'tags': {<Tags.CONTINUOUS: 'continuous'>, <Tags.LIST: 'list'>}, 'properties': {'value_count': {'min': 0, 'max': 20}}, 'dtype': DType(name='float32', element_type=<ElementType.Float: 'float'>, element_size=32, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=0, max=20)))), 'is_list': True, 'is_ragged': True}, {'name': 'category-list', 'tags': {<Tags.CATEGORICAL: 'categorical'>, <Tags.LIST: 'list'>}, 'properties': {'num_buckets': None, 'freq_threshold': 0, 'max_size': 0, 'cat_path': './/categories/unique.category.parquet', 'domain': {'min': 0, 'max': 336, 'name': 'category'}, 'embedding_sizes': {'cardinality': 337, 'dimension': 42}, 'value_count': {'min': 0, 'max': 20}}, 'dtype': DType(name='int64', element_type=<ElementType.Int: 'int'>, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=0, max=20)))), 'is_list': True, 'is_ragged': True}, {'name': 'day_index', 'tags': {<Tags.CATEGORICAL: 'categorical'>}, 'properties': {}, 'dtype': DType(name='int64', element_type=<ElementType.Int: 'int'>, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None),))), 'is_list': False, 'is_ragged': False}]"
]
},
"execution_count": 14,
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
@@ -719,7 +801,7 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 17,
"id": "2d035a88-2146-4b9a-96fd-dd42be86e2a1",
"metadata": {},
"outputs": [],
@@ -747,19 +829,23 @@
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": 18,
"id": "2b4f5b73-459c-4356-87c8-9afb974cc77d",
"metadata": {},
"outputs": [],
"source": [
"# read in the processed train dataset\n",
"sessions_gdf = cudf.read_parquet(os.path.join(DATA_FOLDER, \"processed_nvt/part_0.parquet\"))\n",
"sessions_gdf = sessions_gdf[sessions_gdf.day_index>=178]"
"if USE_SYNTHETIC:\n",
" THRESHOLD_DAY_INDEX = int(os.environ.get(\"THRESHOLD_DAY_INDEX\", '1'))\n",
" sessions_gdf = sessions_gdf[sessions_gdf.day_index>=THRESHOLD_DAY_INDEX]\n",
"else:\n",
" sessions_gdf = sessions_gdf[sessions_gdf.day_index>=178]"
]
},
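
In the synthetic branch above, the day-index cutoff is configurable; a sketch of overriding it before the cell runs (the value is illustrative):

```python
# Illustrative override of the THRESHOLD_DAY_INDEX cutoff read above.
import os

os.environ["THRESHOLD_DAY_INDEX"] = "2"  # keep sessions from day index 2 onward
```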
{
"cell_type": "code",
"execution_count": 17,
"execution_count": 19,
"id": "e18d9c63",
"metadata": {},
"outputs": [
@@ -783,13 +869,13 @@
"6606149 [-0.7818309228245777, -0.7818309228245777] \n",
"\n",
" product_recency_days_log_norm-list \\\n",
"6606147 [1.5241553, 1.5238751, 1.5239341, 1.5241631, 1... \n",
"6606148 [-0.5330064, 1.521494] \n",
"6606149 [1.5338266, 1.5355074] \n",
"6606147 [1.5241561, 1.523876, 1.523935, 1.5241641, 1.5... \n",
"6606148 [-0.533007, 1.521495] \n",
"6606149 [1.5338274, 1.5355083] \n",
"\n",
" category-list day_index \n",
"6606147 [4, 4, 4, 4, 4, 4, 4, 4, 4, 3, 4, 4] 178 \n",
"6606148 [3, 3] 178 \n",
"6606147 [4, 4, 4, 4, 4, 4, 4, 4, 4, 1, 4, 4] 178 \n",
"6606148 [1, 3] 178 \n",
"6606149 [8, 8] 180 \n"
]
}
@@ -800,15 +886,15 @@
},
{
"cell_type": "code",
"execution_count": 18,
"execution_count": 20,
"id": "5175aeaf",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Creating time-based splits: 100%|██████████| 5/5 [00:02<00:00, 2.37it/s]\n"
"Creating time-based splits: 100%|██████████| 5/5 [00:02<00:00, 2.24it/s]\n"
]
}
],
@@ -823,17 +909,17 @@
},
{
"cell_type": "code",
"execution_count": 19,
"execution_count": 21,
"id": "3bd1bad9",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"583"
"748"
]
},
"execution_count": 19,
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}