Skip to content

Commit

Permalink
Fix ugly patch attempt
Browse files Browse the repository at this point in the history
  • Loading branch information
Atcold committed Oct 19, 2023
1 parent a532162 commit a9d7221
Showing 1 changed file with 10 additions and 12 deletions.
22 changes: 10 additions & 12 deletions 15-transformer.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
"outputs": [],
"source": [
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"nn_Softargmax = nn.Softmax # fix wrong name"
"f.softargmax = f.softmax # fix wrong name"
]
},
{
Expand Down Expand Up @@ -70,14 +70,14 @@
" batch_size = Q.size(0) \n",
" k_length = K.size(-2) \n",
" \n",
" # Scaling by d_k so that the soft(arg)max doesnt saturate\n",
" Q = Q / np.sqrt(self.d_k) # (bs, n_heads, q_length, dim_per_head)\n",
" scores = torch.matmul(Q, K.transpose(2,3)) # (bs, n_heads, q_length, k_length)\n",
" # Scaling by d_k so that the softargmax doesnt saturate\n",
" Q = Q / np.sqrt(self.d_k) # (bs, n_heads, q_length, dim_per_head)\n",
" scores = torch.matmul(Q, K.transpose(2,3)) # (bs, n_heads, q_length, k_length)\n",
" \n",
" A = nn_Softargmax(dim=-1)(scores) # (bs, n_heads, q_length, k_length)\n",
" A = f.softargmax(scores, dim=-1) # (bs, n_heads, q_length, k_length)\n",
" \n",
" # Get the weighted average of the values\n",
" H = torch.matmul(A, V) # (bs, n_heads, q_length, dim_per_head)\n",
" H = torch.matmul(A, V) # (bs, n_heads, q_length, dim_per_head)\n",
"\n",
" return H, A \n",
"\n",
Expand Down Expand Up @@ -674,9 +674,7 @@
{
"cell_type": "code",
"execution_count": 33,
"metadata": {
"scrolled": false
},
"metadata": {},
"outputs": [
{
"name": "stdout",
Expand Down Expand Up @@ -749,7 +747,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 [conda env:pDL]",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
Expand All @@ -763,9 +761,9 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.2"
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 2
"nbformat_minor": 4
}

0 comments on commit a9d7221

Please sign in to comment.