Thursday, February 18, 2021

AITech Study Notes - [DAY 19] Transformer, Practice: Multi-head Attention, Masked Multi-head Attention

 =================================

Lecture Content

(Lecture 7) Transformer

The Transformer shows that a sequence model can be built using only attention, which until then had been used just as an add-on (to RNN-based models).



First, let's look at RNNs.

An RNN reads either left to right or, reversed, right to left. Bi-directional RNNs do both: if you feed in the words "I go home", the hidden state for "go" carries information about "I" from the forward pass and about "home" from the backward pass. The two hidden-state vectors are then concatenated so that both are kept.



The query decides the weights. The keys are the material vectors multiplied against the query to decide what to fetch, and the values are the other material vectors, the ones that actually get fetched.

When computing this for each word, the keys and values stay the same across words, but the query q changes. The changed q is multiplied with the keys, the resulting weights are used to linearly combine the value vectors, and that gives the final attention output vector. This is what ends up being h2.


The query and key vectors must have the same dimension, but the value dimension doesn't matter, because the query and key are multiplied element-wise and summed into a single scalar anyway, while the values are only combined using those scalar weights.


A(q,K,V) is a weighted average of the value vectors. The vector obtained by encoding the query through the attention module ends up being d_v-dimensional.

The weights are a similarity based on the dot product with the keys.
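Written out, this is the scaled dot-product attention formula (the 1/√d_k factor is explained right below):

$$A(q, K, V) = \operatorname{softmax}\!\left(\frac{qK^{\top}}{\sqrt{d_k}}\right)V = \sum_i \frac{\exp(q \cdot k_i / \sqrt{d_k})}{\sum_j \exp(q \cdot k_j / \sqrt{d_k})}\, v_i$$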



So why divide by √(d_k) before the softmax?

Each term of the dot product has its own mean and variance, and when you add them up the mean stays 0 but the variance grows. That growth comes purely from the number of terms: the larger the dimension d_k, the more terms are summed and the larger the variance. A larger variance means some scores come out very large, and a big gap between large and small scores has a huge effect on the softmax. So we divide by the standard deviation √(d_k) to normalize; in other words, we normalize once before the softmax so the scores are on the same scale no matter how large or small d_k is.
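A quick check of that claim, assuming the components of q and k are independent with zero mean and unit variance:

$$\operatorname{Var}(q \cdot k) = \operatorname{Var}\!\Big(\sum_{i=1}^{d_k} q_i k_i\Big) = \sum_{i=1}^{d_k} \operatorname{Var}(q_i k_i) = d_k, \qquad \operatorname{sd}(q \cdot k) = \sqrt{d_k}$$

so dividing the scores by √(d_k) brings their standard deviation back to 1.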


Several of these self-attention outputs z are produced — here 8 heads. They are concatenated, and since the result later has to be added back to the input vector as-is (the residual connection), a linear layer shrinks it to the input vector's dimension to produce the final value.
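In formula form, this is the multi-head attention from the Attention Is All You Need paper:

$$\operatorname{MultiHead}(Q, K, V) = \operatorname{Concat}(\text{head}_1, \dots, \text{head}_8)\,W^{O}, \qquad \text{head}_i = \operatorname{Attention}(QW_i^{Q},\, KW_i^{K},\, VW_i^{V})$$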


Here n is the length of the input sequence and d is the dimension of the query.

For the self-attention we used, multiplying the n×d query matrix by the d×n key matrix costs O(n^2 * d) by the definition of matrix multiplication. But this can all be parallelized, so the number of sequential operations is O(1). The maximum path length between words also seems to be O(1), since every query is compared against every word's key directly in a single step.

 

For an RNN, the hidden state h has dimension d and is multiplied by a d×d weight matrix, and this is done once per input position, n times, so the cost is O(n * d^2). But because the computation is sequential it cannot be parallelized, so the number of sequential operations is O(n). The maximum path length is also O(n), since information has to pass through n steps to get across the sequence.




One block is made of Multi-Head Attention, a residual connection, and Layer Normalization; the Feed Forward part is a fully-connected layer. So this is one block built around the Multi-Head Attention, the self-attention proposed in the Transformer.

Up to now we've looked at multi-head attention itself; in the Transformer, extra post-processing is attached on top of it. The residual connection alleviates the gradient-vanishing problem, so many layers can be stacked to raise performance. Then why do the Add & Norm? That's explained below.


For an ordinary neural network (something like linear regression, I think), when a hidden value comes out, whatever mean and variance the hidden state happened to have are thrown away: the mean is subtracted and the result is divided by the standard deviation to normalize it. Then the statistics we actually want are injected back in, I think so the original relationship (say y = 2x + 3, with variance 4 and mean 3) can still be expressed; each node ends up operating with the mean and variance that are most optimal for the value it should detect. Like batch norm, layer norm therefore works in two steps: first the mean and variance of the given sample are made 0 and 1, then the mean and variance we want are injected.



In the figure, when the words "thinking" and "machines" are fed in, the resulting hidden states are normalized and then the desired mean and variance are injected through an affine transformation. That is layer normalization. Batch normalization is said to be a bit different, but roughly similar.
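A minimal sketch of those two steps in PyTorch (gamma/beta are my names for the affine parameters; nn.LayerNorm bundles both steps into one module):

```python
import torch
import torch.nn as nn

d_model = 512
x = torch.randn(10, 20, d_model)            # (batch, seq_len, d_model) hidden states

# Step 1: normalize each position's feature vector to mean 0, variance 1
mu = x.mean(dim=-1, keepdim=True)
var = x.var(dim=-1, keepdim=True, unbiased=False)
x_norm = (x - mu) / torch.sqrt(var + 1e-5)

# Step 2: inject the desired mean/variance via a learnable affine transform
gamma = nn.Parameter(torch.ones(d_model))   # learnable scale
beta = nn.Parameter(torch.zeros(d_model))   # learnable shift
out = gamma * x_norm + beta

print(out.shape)                            # torch.Size([10, 20, 512])
```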



There is one problem: there is no factor that accounts for order, so even if the words were swapped the model couldn't tell the word order apart. This is because the weighted combination with the value vectors is commutative, i.e. order-invariant.

So order is reflected by adding a unique value for each position. In other words, regardless of what the word means, if it sits at a given position a specific value gets added, and through that the position is also baked into the information. These values are built from sin and cos periodic functions with different periods and are added at each position.
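The sinusoidal form used in the paper, for position pos and dimension index i:

$$PE_{(pos,\,2i)} = \sin\!\left(\frac{pos}{10000^{2i/d_{model}}}\right), \qquad PE_{(pos,\,2i+1)} = \cos\!\left(\frac{pos}{10000^{2i/d_{model}}}\right)$$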


As for the learning rate, because of the behavior shown in the slide, it's known empirically that a schedule like that (warm-up followed by decay) works better.


The block we studied is the Encoder block, and it is stacked as multiple layers.

Looking at the word "making", each attention head attends to different places. That's exactly why multi-head attention is used.

https://colab.research.google.com/github/tensorflow/tensor2tensor/blob/master/tensor2tensor/notebooks/hello_t2t.ipynb

 

In the figure, the encoder's output goes into the Decoder. Masked Multi-Head Attention is explained below; in the attention above it, you can see that the Query vectors are made from the decoder-side input, while the Key and Value vectors are the keys and values coming out of the last encoder layer. That's because this attention estimates the relation between the decoder's words and the source words.
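A minimal sketch of that encoder-decoder (cross) attention using PyTorch's built-in module; the shapes and variable names here are mine, not from the lecture:

```python
import torch
import torch.nn as nn

d_model, num_heads = 512, 8
cross_attn = nn.MultiheadAttention(d_model, num_heads)  # expects (seq_len, batch, d_model)

enc_out = torch.randn(20, 10, d_model)     # encoder's final hidden states (src_len, B, d_model)
dec_hidden = torch.randn(15, 10, d_model)  # decoder-side inputs           (tgt_len, B, d_model)

# Query from the decoder, Key and Value from the encoder output
out, attn_weights = cross_attn(query=dec_hidden, key=enc_out, value=enc_out)
print(out.shape)           # torch.Size([15, 10, 512])
print(attn_weights.shape)  # torch.Size([10, 15, 20]) (averaged over heads)
```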

Then, when translating "I go home" into "나는 집에 간다", a linear layer is applied at the end, spread over every Korean word: if the Korean dictionary has 100,000 words, the linear layer is widened to produce 100,000 outputs, a softmax is taken, and the most probable word is emitted.
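A minimal sketch of that output layer (the names generator/dec_out and the toy shapes are my own):

```python
import torch
import torch.nn as nn

d_model, vocab_size = 512, 100_000
generator = nn.Linear(d_model, vocab_size)       # one logit per word in the target vocabulary

dec_out = torch.randn(1, 4, d_model)             # decoder output for 4 target positions
probs = torch.softmax(generator(dec_out), dim=-1)
next_words = probs.argmax(dim=-1)                # most probable word id at each position
print(probs.shape, next_words.shape)             # torch.Size([1, 4, 100000]) torch.Size([1, 4])
```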



During training all the target words are fed in at once, but at actual inference time (in the decoder at test time, I think) it makes no sense to already know which word will come later. Since the model's structure is built to take them all in anyway, a mask seems to be applied as a kind of post-processing.


Since the current input is the query and the rest are the keys, the softmax probabilities from the current query toward the keys of future words are set to 0, I think. Because a sentence is read left to right, it doesn't make sense to understand a word on the left by peeking at a word on its right; even if the left word really is interpreted in light of the right word, the design judges it correct to recover that relation only once the right word is reached, at which point its query can attend back to the left word's key.
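So for position i, the weighted sum effectively runs only over positions j ≤ i:

$$\text{out}_i = \sum_{j \le i} \frac{\exp\!\big(q_i \cdot k_j / \sqrt{d_k}\big)}{\sum_{j' \le i} \exp\!\big(q_i \cdot k_{j'} / \sqrt{d_k}\big)}\, v_j$$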




Practice

(Practice, Lecture 7) Implementing Multi-head Attention

For multi-head attention, the theory says you generate H separate sets of Q, K, V, but because of memory issues and the like, a single matrix is created and then split into sections. In the example above with head = 3, dividing d_model by the number of heads gives three chunks, so you effectively get three real heads, each with the divided length d_k.
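A minimal sketch of that split (shapes assume d_model=512 and num_heads=8, as in the notebook below; variable names are mine):

```python
import torch
import torch.nn as nn

batch_size, seq_len, d_model, num_heads = 10, 20, 512, 8
d_k = d_model // num_heads                       # 64: per-head dimension

w_q = nn.Linear(d_model, d_model)                # one big matrix instead of 8 separate W_i^Q
x = torch.randn(batch_size, seq_len, d_model)

q = w_q(x)                                       # (10, 20, 512)
q = q.view(batch_size, seq_len, num_heads, d_k)  # (10, 20, 8, 64): slice d_model into heads
q = q.transpose(1, 2)                            # (10, 8, 20, 64): one 64-dim chunk per head
print(q.shape)
```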


##7. Multi-head Attention 1. Implement multi-head attention and self-attention. 2. Understand the operations performed at each step and the input/output shapes.

Import required packages

[1]

Data preprocessing

[2]
[3]
[4]
100%|██████████| 10/10 [00:00<00:00, 12826.62it/s]
Maximum sequence length: 20
[5]
[[62, 13, 47, 39, 78, 33, 56, 13, 39, 29, 44, 86, 71, 36, 18, 75, 0, 0, 0, 0],
 [60, 96, 51, 32, 90, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [35, 45, 48, 65, 91, 99, 92, 10, 3, 21, 54, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [75, 51, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [66, 88, 98, 47, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [21, 39, 10, 64, 21, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [98, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [77, 65, 51, 77, 19, 15, 35, 19, 23, 97, 50, 46, 53, 42, 45, 91, 66, 3, 43, 10],
 [70, 64, 98, 25, 99, 53, 4, 13, 69, 62, 66, 76, 15, 75, 45, 34, 0, 0, 0, 0],
 [20, 64, 81, 35, 76, 85, 1, 62, 8, 45, 99, 77, 19, 43, 0, 0, 0, 0, 0, 0]]

Hyperparameter setup and embedding

[6]
[7]
[8]
tensor([[[-0.2169, -0.3583,  1.0193,  ..., -0.7934, -0.9208, -1.0198],
         [-0.8411,  2.4772,  0.9702,  ..., -0.4276, -1.3260, -0.0394],
         ...,
         [-0.3128,  0.5407,  0.1209,  ..., -0.2008, -0.6200, -0.9691]],
        ...], grad_fn=<EmbeddingBackward>) torch.Size([10, 20, 512])

Linear transformation & splitting into multiple heads

Define the linear transformation matrices used inside multi-head attention.

[9]
[10]
[11]
torch.Size([10, 20, 512]) torch.Size([10, 20, 512]) torch.Size([10, 20, 512])
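The code cells themselves aren't preserved here, but a sketch consistent with the printed shapes would look like this (batch_emb is a stand-in for the embedded batch printed above; the other names are my guesses):

```python
import torch
import torch.nn as nn

d_model = 512
batch_emb = torch.randn(10, 20, d_model)   # stand-in for the embedding output above

w_q = nn.Linear(d_model, d_model)          # W^Q
w_k = nn.Linear(d_model, d_model)          # W^K
w_v = nn.Linear(d_model, d_model)          # W^V

q, k, v = w_q(batch_emb), w_k(batch_emb), w_v(batch_emb)
print(q.shape, k.shape, v.shape)           # torch.Size([10, 20, 512]) each
```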

Split Q, K, V into num_heads vectors, each with the divided per-head dimension.

[12]
torch.Size([10, 20, 8, 64]) torch.Size([10, 20, 8, 64]) torch.Size([10, 20, 8, 64])
[13]
torch.Size([10, 8, 20, 64]) torch.Size([10, 8, 20, 64]) torch.Size([10, 8, 20, 64])

Implementing scaled dot-product self-attention

This is the self-attention step executed in each head.

[14]
tensor([[[[0.0483, 0.0464, 0.0778,  ..., 0.0614, 0.0614, 0.0614],
          [0.0362, 0.0656, 0.0156,  ..., 0.0370, 0.0370, 0.0370],
          ...]]], grad_fn=<SoftmaxBackward>) torch.Size([10, 8, 20, 20])
[15]
torch.Size([10, 8, 20, 64])
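A sketch of the computation behind those two outputs, matching the printed shapes (attn_dists: (10, 8, 20, 20), attn_values: (10, 8, 20, 64)); names are mine:

```python
import math
import torch
import torch.nn.functional as F

batch_size, num_heads, seq_len, d_k = 10, 8, 20, 64
q = torch.randn(batch_size, num_heads, seq_len, d_k)
k = torch.randn(batch_size, num_heads, seq_len, d_k)
v = torch.randn(batch_size, num_heads, seq_len, d_k)

scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)  # (10, 8, 20, 20)
attn_dists = F.softmax(scores, dim=-1)                          # attention distribution per query
attn_values = torch.matmul(attn_dists, v)                       # (10, 8, 20, 64)
print(attn_dists.shape, attn_values.shape)
```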

Merging the results of each head

Concatenate the results of the heads and apply a linear transformation back to the original dimension.

[16]
torch.Size([10, 20, 512])
[17]
tensor([[[-1.1352e-01, -1.9139e-01,  3.4395e-02,  ...,  4.0102e-02, -2.2987e-01,  1.9529e-01],
         ...]], grad_fn=<AddBackward0>) torch.Size([10, 20, 512])
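Roughly what those cells do, continuing from the attn_values above (w_0 is my name for the output projection):

```python
import torch
import torch.nn as nn

batch_size, num_heads, seq_len, d_k = 10, 8, 20, 64
d_model = num_heads * d_k

attn_values = torch.randn(batch_size, num_heads, seq_len, d_k)

x = attn_values.transpose(1, 2)                        # (10, 20, 8, 64): put heads next to d_k
x = x.contiguous().view(batch_size, seq_len, d_model)  # (10, 20, 512): concat the heads
w_0 = nn.Linear(d_model, d_model)                      # final linear back to d_model
out = w_0(x)
print(out.shape)                                       # torch.Size([10, 20, 512])
```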

Full code

Putting all of the above together, we implement a single multi-head attention module.

[18]
[19]
[20]
tensor([[[ 9.2516e-03,  1.7369e-01,  1.6916e-03,  ...,  1.1210e-01, -1.6726e-02,  3.4743e-01],
         ...]], grad_fn=<AddBackward0>) torch.Size([10, 20, 512])
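The module cell isn't preserved here, but a self-contained sketch that matches the printed shape would be roughly (class and attribute names are my own):

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiheadAttention(nn.Module):
    def __init__(self, d_model=512, num_heads=8):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_model, self.num_heads = d_model, num_heads
        self.d_k = d_model // num_heads

        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_0 = nn.Linear(d_model, d_model)

    def forward(self, q, k, v):
        batch_size = q.size(0)

        # Linear transform, then split d_model into (num_heads, d_k)
        q = self.w_q(q).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        k = self.w_k(k).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        v = self.w_v(v).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

        # Scaled dot-product attention in every head at once
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        attn_dists = F.softmax(scores, dim=-1)
        attn_values = torch.matmul(attn_dists, v)          # (B, H, L, d_k)

        # Concatenate heads and project back to d_model
        concat = attn_values.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        return self.w_0(concat)

x = torch.randn(10, 20, 512)
print(MultiheadAttention()(x, x, x).shape)   # torch.Size([10, 20, 512])
```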



(Practice, Lecture 8) Implementing Masked Multi-head Attention


The masking step: keep the model from seeing the future.



##8. Masked Multi-head Attention 1. Implement Masked Multi-head Attention. 2. Implement Encoder-Decoder Attention.

Import required packages

[1]

Data preprocessing

To see the values and shapes of the data more clearly, the number of samples is reduced.

[2]
[3]
[4]
100%|██████████| 5/5 [00:00<00:00, 3296.37it/s]
Maximum sequence length: 10
[5]
[[62, 13, 47, 39, 78, 33, 56, 13, 0, 0],
 [60, 96, 51, 32, 90, 0, 0, 0, 0, 0],
 [35, 45, 48, 65, 91, 99, 92, 10, 3, 21],
 [66, 88, 98, 47, 0, 0, 0, 0, 0, 0],
 [77, 65, 51, 77, 19, 15, 35, 19, 23, 0]]

Hyperparameter setup and embedding

[6]
[7]
[8]
tensor([[[ 2.5978e-02, -1.1719e+00, -5.6547e-01,  1.0690e+00, -7.4584e-01, -1.0695e+00,  1.4428e+00, -2.7004e+00],
         ...]], grad_fn=<EmbeddingBackward>) torch.Size([5, 10, 8])

Building the mask

True marks positions where attention is applied; False marks positions to be masked.

[9]
tensor([[[ True,  True,  True,  True,  True,  True,  True,  True, False, False]],
        [[ True,  True,  True,  True,  True, False, False, False, False, False]],
        [[ True,  True,  True,  True,  True,  True,  True,  True,  True,  True]],
        [[ True,  True,  True,  True, False, False, False, False, False, False]],
        [[ True,  True,  True,  True,  True,  True,  True,  True,  True, False]]]) torch.Size([5, 1, 10])
[10]
tensor([[[ True, False, False, False, False, False, False, False, False, False],
         [ True,  True, False, False, False, False, False, False, False, False],
         [ True,  True,  True, False, False, False, False, False, False, False],
         [ True,  True,  True,  True, False, False, False, False, False, False],
         [ True,  True,  True,  True,  True, False, False, False, False, False],
         [ True,  True,  True,  True,  True,  True, False, False, False, False],
         [ True,  True,  True,  True,  True,  True,  True, False, False, False],
         [ True,  True,  True,  True,  True,  True,  True,  True, False, False],
         [ True,  True,  True,  True,  True,  True,  True,  True,  True, False],
         [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True]]]) torch.Size([1, 10, 10])
[11]
tensor([[[ True, False, False, False, False, False, False, False, False, False],
         [ True,  True, False, False, False, False, False, False, False, False],
         ...,
         [ True,  True,  True,  True,  True,  True,  True,  True, False, False]],
        ...]) torch.Size([5, 10, 10])
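A sketch of how these masks are typically built: a padding mask from the pad id plus a lower-triangular "no future" mask, combined with &. Here pad_id = 0 is my assumption based on the padded data above, and the sample sentence is the first row of that data:

```python
import torch

pad_id = 0
batch = torch.tensor([[62, 13, 47, 39, 78, 33, 56, 13, 0, 0]])    # (B, L) padded token ids

padding_mask = (batch != pad_id).unsqueeze(1)                     # (B, 1, L): False on pad tokens
seq_len = batch.size(1)
nopeak_mask = torch.tril(torch.ones(1, seq_len, seq_len)).bool()  # (1, L, L): lower triangular
mask = padding_mask & nopeak_mask                                 # (B, L, L): combined mask
print(mask.shape)
```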

Linear transformation & splitting into multiple heads

[12]
[13]
torch.Size([5, 2, 10, 4]) torch.Size([5, 2, 10, 4]) torch.Size([5, 2, 10, 4])

Implementing self-attention with masking applied

[14]
[15]
tensor([[[[ 4.7637e-01, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12],
          [ 4.0336e-01,  1.5477e-02, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12],
          ...]]], grad_fn=<MaskedFillBackward0>) torch.Size([5, 2, 10, 10])

The positions masked with -1e12 (standing in for -inf) become 0 after the softmax.

[16]
tensor([[[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.5958, 0.4042, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          ...]]], grad_fn=<SoftmaxBackward>) torch.Size([5, 2, 10, 10])
[17]
torch.Size([5, 2, 10, 4])
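
The zeros in the attention matrices above come from masking the scores before the softmax, so each row is a distribution only over the positions that row is allowed to see. A minimal sketch of that step, assuming the shapes printed above (batch 5, 2 heads, length 10, d_k 4); the variable names are mine and the notebook also seems to combine a padding mask, which is omitted here:

import torch
import torch.nn.functional as F

# Assumed shapes: batch=5, heads=2, seq_len=10, d_k=4 (matching the outputs above).
B, H, L, d_k = 5, 2, 10, 4
q = torch.randn(B, H, L, d_k)
k = torch.randn(B, H, L, d_k)
v = torch.randn(B, H, L, d_k)

# Causal mask: position i may only attend to positions <= i.
mask = torch.tril(torch.ones(L, L, dtype=torch.bool))            # (L, L)

scores = torch.matmul(q, k.transpose(-2, -1)) / (d_k ** 0.5)     # (B, H, L, L)
scores = scores.masked_fill(~mask, float('-inf'))                # hide future positions
attn = F.softmax(scores, dim=-1)   # masked entries become exactly 0, each row sums to 1
out = torch.matmul(attn, v)        # (B, H, L, d_k) -> torch.Size([5, 2, 10, 4])
print(attn.shape, out.shape)

Filling the masked scores with -inf (rather than zero) is what makes them come out as exact 0.0000 probabilities after the softmax.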

Full code

[18]
[19]
[20]
tensor([[[ 0.0144, 0.1497, 0.1135, 1.3092, 0.9384, 0.3099, -0.1766, -0.5622], [-0.2898, 0.0896, 0.1427, 0.8015, 0.4785, 0.0681, 0.0933, -0.4136], [-0.3018, -0.0423, -0.0087, 0.4880, 0.5833, -0.1454, 0.1864, -0.2006], [-0.4746, 0.2947, 0.0155, 0.5849, 0.3672, -0.1584, 0.2451, -0.2677], [-0.4331, 0.4161, -0.0199, 0.6722, 0.3977, -0.1417, 0.2226, -0.2819], [-0.2968, 0.3491, -0.1067, 0.6743, 0.5806, -0.0976, 0.1937, -0.2695], [-0.2873, 0.2195, -0.0550, 0.4832, 0.3601, -0.0188, 0.2223, -0.3254], [-0.3359, 0.1678, -0.0062, 0.4428, 0.3006, -0.0507, 0.2521, -0.2963], [-0.2698, 0.2142, -0.0160, 0.5737, 0.4235, -0.0383, 0.2095, -0.2744], [-0.2698, 0.2142, -0.0160, 0.5737, 0.4235, -0.0383, 0.2095, -0.2744]], [[-0.1768, -0.0337, 0.3418, 0.6498, 0.1639, 0.2566, 0.0251, -0.3522], [-0.0263, -0.0877, 0.3811, 0.5641, -0.0212, 0.4380, 0.0564, -0.4133], [-0.0217, -0.1415, 0.3251, 0.3769, -0.1307, 0.4615, 0.1472, -0.4424], [-0.0246, -0.4154, 0.2211, -0.0939, -0.2521, 0.4058, 0.2911, -0.3925], [-0.0876, -0.3316, 0.1590, -0.0801, -0.1808, 0.3044, 0.3026, -0.3570], [-0.0793, -0.3021, 0.1924, -0.0047, -0.1807, 0.3280, 0.2785, -0.3688], [-0.0793, -0.3021, 0.1924, -0.0047, -0.1807, 0.3280, 0.2785, -0.3688], [-0.0793, -0.3021, 0.1924, -0.0047, -0.1807, 0.3280, 0.2785, -0.3688], [-0.0793, -0.3021, 0.1924, -0.0047, -0.1807, 0.3280, 0.2785, -0.3688], [-0.0793, -0.3021, 0.1924, -0.0047, -0.1807, 0.3280, 0.2785, -0.3688]], [[ 0.0925, -0.2645, 0.2301, 0.8331, 0.6290, 0.4479, -0.0454, -0.4824], [-0.0289, 0.2087, 0.0284, 0.5535, 0.2514, 0.0563, 0.1032, -0.2719], [ 0.0734, 0.0038, 0.0775, 0.3035, -0.0467, 0.2196, 0.1020, -0.4090], [ 0.1093, -0.1755, 0.1103, 0.3176, 0.1820, 0.2193, 0.1487, -0.2609], [-0.0246, -0.0402, -0.0474, 0.3214, 0.3845, 0.0386, 0.1519, -0.2278], [-0.0277, -0.1558, 0.0172, 0.2918, 0.3786, 0.0965, 0.1986, -0.1995], [-0.0653, -0.2249, 0.0930, 0.1910, 0.1728, 0.1516, 0.2128, -0.2602], [-0.0925, -0.2369, 0.0983, 0.0908, 0.1850, -0.0380, 0.3283, 0.0092], [-0.1878, -0.3238, 0.1693, 0.0309, 0.0890, -0.0245, 0.3728, -0.0209], [-0.1793, -0.2081, 0.0681, 0.0427, 0.1342, -0.0089, 0.3275, -0.1020]], [[-0.0979, -0.3978, 0.2622, -0.3837, -0.6509, 0.4486, 0.4687, -0.3805], [-0.1681, -0.2912, 0.2294, -0.0610, -0.2760, 0.2384, 0.3482, -0.3007], [-0.0446, -0.3829, 0.2550, -0.1610, -0.3536, 0.2851, 0.3518, -0.2386], [-0.1189, -0.2787, 0.0283, -0.2147, -0.1060, 0.1434, 0.3567, -0.2334], [-0.1640, -0.3223, 0.1416, -0.1231, -0.1249, 0.0697, 0.3680, -0.1537], [-0.1640, -0.3223, 0.1416, -0.1231, -0.1249, 0.0697, 0.3680, -0.1537], [-0.1640, -0.3223, 0.1416, -0.1231, -0.1249, 0.0697, 0.3680, -0.1537], [-0.1640, -0.3223, 0.1416, -0.1231, -0.1249, 0.0697, 0.3680, -0.1537], [-0.1640, -0.3223, 0.1416, -0.1231, -0.1249, 0.0697, 0.3680, -0.1537], [-0.1640, -0.3223, 0.1416, -0.1231, -0.1249, 0.0697, 0.3680, -0.1537]], [[ 0.2386, -0.8116, -0.0987, -0.8957, -0.2766, 0.2600, 0.5507, -0.0985], [-0.0426, -0.6154, -0.0750, -0.5514, -0.0237, 0.1282, 0.4729, -0.1658], [-0.0397, -0.4395, 0.0520, -0.2962, -0.1769, 0.3056, 0.4293, -0.3281], [ 0.0077, -0.5226, -0.0144, -0.4651, -0.1768, 0.2642, 0.4592, -0.2693], [-0.0573, -0.3984, -0.0025, -0.3258, -0.1420, 0.1969, 0.4423, -0.2659], [-0.1514, -0.2093, 0.0159, -0.0258, -0.0172, 0.2169, 0.3462, -0.3935], [-0.1374, -0.1000, -0.0064, 0.1789, 0.1111, 0.2921, 0.2488, -0.5185], [-0.1135, -0.2408, 0.0513, 0.0704, 0.0714, 0.2074, 0.3071, -0.3553], [-0.0981, -0.3065, 0.0201, -0.1530, -0.0656, 0.1096, 0.3669, -0.2580], [-0.0758, -0.3004, 0.0695, 0.0510, 0.1120, 0.1087, 0.3016, -0.2206]]], 
grad_fn=<AddBackward0>) torch.Size([5, 10, 8])
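
For reference, the [5, 10, 8] shape above is what you get once the per-head results are put back together and projected to d_model (the AddBackward0 suggests an add, e.g. a bias or residual, comes right after). A rough sketch under those assumed shapes; w_o is my name for the output projection, the notebook's code may differ:

import torch

# Assumed shapes: batch=5, heads=2, seq_len=10, d_k=4, d_model=8.
B, H, L, d_k = 5, 2, 10, 4
d_model = H * d_k
out_per_head = torch.randn(B, H, L, d_k)   # attention output per head

# Recombine heads: (B, H, L, d_k) -> (B, L, H*d_k), then project back to d_model.
concat = out_per_head.transpose(1, 2).contiguous().view(B, L, d_model)  # (5, 10, 8)
w_o = torch.nn.Linear(d_model, d_model)
final = w_o(concat)                        # torch.Size([5, 10, 8])
print(final.shape)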

Encoder-Decoder attention

Only the query, key, and value inputs change; the implementation itself is the same.
Only the batch that goes into the decoder is built separately here.

[21]
100%|██████████| 5/5 [00:00<00:00, 4245.25it/s]
Maximum sequence length: 12
[22]
torch.Size([5, 10]) torch.Size([5, 12])
[23]
torch.Size([5, 10, 8]) torch.Size([5, 12, 8])

We take src_emb to be the encoder output and trg_emb to be the result after masked multi-head attention.

[24]
torch.Size([5, 2, 12, 4]) torch.Size([5, 2, 10, 4]) torch.Size([5, 2, 10, 4])
[25]
torch.Size([5, 2, 12, 10])
[26]
torch.Size([5, 2, 12, 4])

The result has the same shape as the output of masked multi-head attention, and all computation in the following layers proceeds in exactly the same way.
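
Putting the shape checks above together, a minimal sketch of this encoder-decoder attention step could look like the following. src_emb and trg_emb are the stand-ins described above (encoder output of length 10, decoder side of length 12); w_q, w_k, w_v and split_heads are my own names, not necessarily the notebook's:

import torch
import torch.nn.functional as F

# Assumed shapes from the printed outputs: batch=5, heads=2, d_model=8, d_k=4.
B, H, d_model, d_k = 5, 2, 8, 4
src_len, trg_len = 10, 12
src_emb = torch.randn(B, src_len, d_model)   # stand-in for the encoder output
trg_emb = torch.randn(B, trg_len, d_model)   # stand-in for the masked MHA result

w_q = torch.nn.Linear(d_model, d_model)
w_k = torch.nn.Linear(d_model, d_model)
w_v = torch.nn.Linear(d_model, d_model)

def split_heads(x):
    # (B, L, d_model) -> (B, H, L, d_k)
    return x.view(B, -1, H, d_k).transpose(1, 2)

q = split_heads(w_q(trg_emb))   # (5, 2, 12, 4)  queries come from the decoder side
k = split_heads(w_k(src_emb))   # (5, 2, 10, 4)  keys come from the encoder output
v = split_heads(w_v(src_emb))   # (5, 2, 10, 4)  values come from the encoder output

attn = F.softmax(q @ k.transpose(-2, -1) / d_k ** 0.5, dim=-1)   # (5, 2, 12, 10)
out = attn @ v                                                    # (5, 2, 12, 4)
print(q.shape, k.shape, v.shape, attn.shape, out.shape)

Each of the 12 decoder positions attends over the 10 encoder positions, which is exactly why the attention matrix is [5, 2, 12, 10] and the output keeps the decoder-side length.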



==================================

Assignment / Quiz

https://wikidocs.net/22592


Natural Language Processing

Assignment 4: Byte Pair Encoding

1. Introduction

  • If we generate a single embedding per word, we run into the critical out-of-vocabulary (OOV) problem: any word that never appeared in the training data has to be fed to the model as an unknown token, which can degrade overall performance. On the other hand, building an embedding for every possible word requires far too many embedding parameters. Sub-word tokenization, which applies byte pair encoding, a data-compression algorithm, to how words are represented for the model, was introduced to address this.
  • In this assignment we implement a simple sub-word tokenizer based on byte pair encoding. Complete the build_bpe function by following the instructions in the assignment notebook, the docstring of each function, and Algorithm 1 on page 3 of the paper, and make all test cases pass (a rough sketch of the merge loop is shown right below).
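
For orientation, the merge loop that produces the trace in the next cell is essentially Algorithm 1 from the paper. A minimal sketch, assuming the usual get_stats / merge_vocab helpers from the paper's reference code; the notebook's actual build_bpe signature may differ:

import re
from collections import defaultdict

def get_stats(vocab):
    # Count adjacent symbol pairs, weighted by word frequency.
    pairs = defaultdict(int)
    for word, freq in vocab.items():
        symbols = word.split()
        for i in range(len(symbols) - 1):
            pairs[symbols[i], symbols[i + 1]] += freq
    return pairs

def merge_vocab(pair, vocab):
    # Merge every occurrence of `pair` into a single symbol.
    bigram = re.escape(' '.join(pair))
    pattern = re.compile(r'(?<!\S)' + bigram + r'(?!\S)')
    return {pattern.sub(''.join(pair), word): freq for word, freq in vocab.items()}

# Toy corpus matching the trace below; '</w>' marks the end of a word.
vocab = {'l o w </w>': 5, 'l o w e r </w>': 2,
         'n e w e s t </w>': 6, 'w i d e s t </w>': 3}
for i in range(10):
    pairs = get_stats(vocab)
    best = max(pairs, key=pairs.get)   # most frequent pair, e.g. ('e', 's') first
    vocab = merge_vocab(best, vocab)
    print(i, best)

Each merged pair becomes a new sub-word token, so the learned vocabulary grows from characters toward frequent word pieces like 'est' and 'low'.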
[1]
0 defaultdict(<class 'int'>, {('l', 'o'): 7, ('o', 'w'): 7, ('w', '</w>'): 5, ('w', 'e'): 8, ('e', 'r'): 2, ('r', '</w>'): 2, ('n', 'e'): 6, ('e', 'w'): 6, ('e', 's'): 9, ('s', 't'): 9, ('t', '</w>'): 9, ('w', 'i'): 3, ('i', 'd'): 3, ('d', 'e'): 3}) ('e', 's') {'l o w </w>': 5, 'l o w e r </w>': 2, 'n e w es t </w>': 6, 'w i d es t </w>': 3}
1 defaultdict(<class 'int'>, {('l', 'o'): 7, ('o', 'w'): 7, ('w', '</w>'): 5, ('w', 'e'): 2, ('e', 'r'): 2, ('r', '</w>'): 2, ('n', 'e'): 6, ('e', 'w'): 6, ('w', 'es'): 6, ('es', 't'): 9, ('t', '</w>'): 9, ('w', 'i'): 3, ('i', 'd'): 3, ('d', 'es'): 3}) ('es', 't') {'l o w </w>': 5, 'l o w e r </w>': 2, 'n e w est </w>': 6, 'w i d est </w>': 3}
2 defaultdict(<class 'int'>, {('l', 'o'): 7, ('o', 'w'): 7, ('w', '</w>'): 5, ('w', 'e'): 2, ('e', 'r'): 2, ('r', '</w>'): 2, ('n', 'e'): 6, ('e', 'w'): 6, ('w', 'est'): 6, ('est', '</w>'): 9, ('w', 'i'): 3, ('i', 'd'): 3, ('d', 'est'): 3}) ('est', '</w>') {'l o w </w>': 5, 'l o w e r </w>': 2, 'n e w est</w>': 6, 'w i d est</w>': 3}
3 defaultdict(<class 'int'>, {('l', 'o'): 7, ('o', 'w'): 7, ('w', '</w>'): 5, ('w', 'e'): 2, ('e', 'r'): 2, ('r', '</w>'): 2, ('n', 'e'): 6, ('e', 'w'): 6, ('w', 'est</w>'): 6, ('w', 'i'): 3, ('i', 'd'): 3, ('d', 'est</w>'): 3}) ('l', 'o') {'lo w </w>': 5, 'lo w e r </w>': 2, 'n e w est</w>': 6, 'w i d est</w>': 3}
4 defaultdict(<class 'int'>, {('lo', 'w'): 7, ('w', '</w>'): 5, ('w', 'e'): 2, ('e', 'r'): 2, ('r', '</w>'): 2, ('n', 'e'): 6, ('e', 'w'): 6, ('w', 'est</w>'): 6, ('w', 'i'): 3, ('i', 'd'): 3, ('d', 'est</w>'): 3}) ('lo', 'w') {'low </w>': 5, 'low e r </w>': 2, 'n e w est</w>': 6, 'w i d est</w>': 3}
5 defaultdict(<class 'int'>, {('low', '</w>'): 5, ('low', 'e'): 2, ('e', 'r'): 2, ('r', '</w>'): 2, ('n', 'e'): 6, ('e', 'w'): 6, ('w', 'est</w>'): 6, ('w', 'i'): 3, ('i', 'd'): 3, ('d', 'est</w>'): 3}) ('n', 'e') {'low </w>': 5, 'low e r </w>': 2, 'ne w est</w>': 6, 'w i d est</w>': 3}
6 defaultdict(<class 'int'>, {('low', '</w>'): 5, ('low', 'e'): 2, ('e', 'r'): 2, ('r', '</w>'): 2, ('ne', 'w'): 6, ('w', 'est</w>'): 6, ('w', 'i'): 3, ('i', 'd'): 3, ('d', 'est</w>'): 3}) ('ne', 'w') {'low </w>': 5, 'low e r </w>': 2, 'new est</w>': 6, 'w i d est</w>': 3}
7 defaultdict(<class 'int'>, {('low', '</w>'): 5, ('low', 'e'): 2, ('e', 'r'): 2, ('r', '</w>'): 2, ('new', 'est</w>'): 6, ('w', 'i'): 3, ('i', 'd'): 3, ('d', 'est</w>'): 3}) ('new', 'est</w>') {'low </w>': 5, 'low e r </w>': 2, 'newest</w>': 6, 'w i d est</w>': 3}
8 defaultdict(<class 'int'>, {('low', '</w>'): 5, ('low', 'e'): 2, ('e', 'r'): 2, ('r', '</w>'): 2, ('w', 'i'): 3, ('i', 'd'): 3, ('d', 'est</w>'): 3}) ('low', '</w>') {'low</w>': 5, 'low e r </w>': 2, 'newest</w>': 6, 'w i d est</w>': 3}
9 defaultdict(<class 'int'>, {('low', 'e'): 2, ('e', 'r'): 2, ('r', '</w>'): 2, ('w', 'i'): 3, ('i', 'd'): 3, ('d', 'est</w>'): 3}) ('w', 'i') {'low</w>': 5, 'low e r </w>': 2, 'newest</w>': 6, 'wi d est</w>': 3}

2-1. Complete the build_bpe function.

[5]
Counter({'newest': 6, 'low': 5, 'widest': 3, 'lower': 2}) {'low': 5, 'lower': 2, 'newest': 6, 'widest': 3}
[51]

2-2. Evaluating the build_bpe function

[52]
======Building BPE Vocab Test Case======
['<pad>', '<unk>', '<cls>', '<sep>', '<msk>', 'abcde', 'abcd', 'abc', 'ab', 'c', 'a', 'e', 'b', 'd', '_']
The first test passed!
The second test passed!
['<pad>', '<unk>', '<cls>', '<sep>', '<msk>', 'newest_', 'est_', 'low', 'new', 'est', 'lo', 'es', 'ne', 'o', 'w', 'l', 'e', 's', 'i', 't', 'r', 'n', 'd', '_']
The third test passed!
['<pad>', '<unk>', '<cls>', '<sep>', '<msk>', 'aaaaaaaa', 'aaaa', 'abab', 'aa', 'ab', 'a', 'b', '_']
The forth test passed!
['<pad>', '<unk>', '<cls>', '<sep>', '<msk>', 'abc_', 'bcd_', 'abc', 'bcd', 'bc', 'c', 'a', 'b', 'd', '_']
The fifth test passed!
All 5 tests passed!



==================================

Peer session

Reviewed the material and tried some data analysis and preprocessing for DACON.



===================================

Reflections

Tired.


