[Paper Review] Generalization Through Memorization: Nearest Neighbor Language Models
Paper Review for kNN-LM model
0. Abstract ๐ฌ
kNN-LM
์ pre-train
๋ LM์ ์ ํ์ ์ผ๋ก kNN ์๊ณ ๋ฆฌ์ฆ๊ณผ ๊ฒฐํฉํ์ฌ ํ์ฅํ ๋ชจ๋ธ์ด๋ค. pre-trained LM
์ ์ด์ฉํด์ input ๋ฐ์ดํฐ๊ฐ latent space
๋ก ์๋ฒ ๋ฉ ๋๊ฒ ๋๋๋ฐ, ์ด latent space
์์ ๋ฒกํฐ ๊ฐ์ ๊ฑฐ๋ฆฌ๋ฅผ ํตํด์ ๊ฐ์ฅ ๊ฐ๊น์ด k ๊ฐ์ ํ๋ณด๋ฅผ ์ ํ๊ฒ ๋๋ค. ์ด๋ ์์์ ๋ฐ์ดํฐ์
(including training data) ์ ํตํด์ ๊ฐ๋ฅํ๋ค.
์ด ๋ฐฉ์์ Wikitext-103LM
์ ์ ์ฉํจ์ผ๋ก์จ, ์ด ๋
ผ๋ฌธ์์ ์๊ฐํ๋ ๋ชจ๋ธ์ SOTA๋ฅผ ๋ฌ์ฑํ์ผ๋ฉฐ, ์ถ๊ฐ์ ์ธ training ์์ด๋ 15.79์ perplexity
๋ก 2.9 point๋ ์ค์ด๋ ํจ๊ณผ๋ฅผ ๋ณด์๋ค. ๋ํ, ์ด ์ ๊ทผ๋ฒ์ ๋ ํฐ ํ๋ จ ๋ฐ์ดํฐ์
, ๊ทธ๋ฆฌ๊ณ ๋ค๋ฅธ domain
์ผ๋ก์ ์ ์ฉ ์ญ์ ํจ๊ณผ์ ์ผ๋ก ์ํํ ์ ์์๋ค.
์ง์ ์ผ๋ก๋, ์ด ๋ชจ๋ธ์ ์์ํ ํํ๋ค์ ๋ํด์ ๋์ฑ ํจ๊ณผ์ ์ธ ๋ชจ์ต์ ๋ณด์๊ณ , ํนํ factual knowledege
์ ๋ํด์ ํจ๊ณผ์ ์ด์๋ค. ๋์์ ์ด ์ฐ๊ตฌ๋ LM ์ ๊ทผ๋ณธ์ ์ธ task ์ธ next token prediction
๋ณด๋ค sequences ๊ฐ์ similarity
๋ฅผ ํ์ตํ๋ ๊ฒ์ด ๋ ํจ๊ณผ์ ์ธ ์ ๊ทผ ๋ฐฉ์์์ ์๋ฏธํ๊ธฐ๋ ํ๋ค.
1. Introduction โ๏ธ
Language Model ์ ์ผ๋ฐ์ ์ผ๋ก ์๋์ 2๊ฐ์ง task๋ฅผ ๋ชฉํ๋ก ํ๋ค.
1. ์ฆ ๋ฌธ์ฅ์ prefix๋ฅผ n์ฐจ์ ๋ฒกํฐ๋ก ๋ํ๋ธ๋ค. (์ ํํ๋ ๊ณ ์ ํฌ๊ธฐ์ representation)
2. ์ด๋ ๊ฒ ๋ง๋ค์ด์ง latent space์์์ ๊ฐ์ ์ด์ฉํด์ ๋ค์ ๋จ์ด๋ฅผ ์์ธกํ๋ค.
๋ณธ ๋
ผ๋ฌธ์์๋ ์ฒซ๋ฒ์งธ task๊ฐ ๋๋ฒ์งธ task ๋ณด๋ค ์ฌ์ด task๋ผ๋ ๊ฐ์ ํ์ ์ ๊ทผํ๋ค. ์๋ฅผ ๋ค์ด, โDickens is the author ofโ ๋ผ๋ ๋ฌธ์ฅ๊ณผ โDickens wroteโ ๋ผ๋ ๋ฌธ์ฅ์ ๋ณด์์ ๋, ๊ทธ ํ์ ์ฌ ๋จ์ด๋ฅผ ์์ธกํ์ง ๋ชปํ๋๋ผ๋ ๋ ๋ฌธ์ฅ์ด ๊ฐ์ ๋ป์ ๋ดํฌํ๊ณ ์์์ ๋๊ตฌ๋ ์ ์ ์๋ค. ์คํ์ ์ผ๋ก๋, prefix embedding
์ ๋ํด kNN์ ์ ์ฉ์ํจ ๊ฒฐ๊ณผ, ์ฑ๋ฅ์ด ํฅ์๋จ์ ํตํด LM์ด ์ฒซ๋ฒ์งธ task์ ๋ ํจ๊ณผ์ ์ด๋ผ๋ ๊ฐ๋ ฅํ ์ฆ๊ฑฐ๋ฅผ ์ ์ํ๋ค.
3-billion
๊ฐ์ token
์ ๋ชจ๋ธ์ ํ์ต ๋ฐ์ดํฐ๋ก ์ฌ์ฉํ๋ ๊ฒ๋ณด๋ค, 100-million
๊ฐ์ token
์ ์ด์ฉํด ํ์ตํ๊ณ 3-billion๊ฐ์ token
์ ๊ฐ์ง๋ dataset(documents)
์ ์ด ๋ชจ๋ธ์ ์ ์ฉํ๋ ๊ฒ์ด ๋ ๋์ ์ฑ๋ฅ์ ๋ณด์๋ค. ์ด๋ ๊ณง Large dataset์ ์ฌ์ฉํ๋ LM์ ๋ํ ์๋ก์ด ๋ฐฉํฅ์ฑ์ ์ ์ํ๋ค. ๋น์ทํ๊ฒ, ๋จ์ํ datastore
์ ๋ค๋ฅธ domain์ ๋ฐ์ดํฐ๋ฅผ ์ฝ์
ํ๋ ๊ฒ๋ง์ผ๋ก๋ multiple domain
์์๋ ํจ๊ณผ์ ์ธ ์ฑ๋ฅ์ ๋ณด์๋ค.
๋ง์ง๋ง์ผ๋ก, ์ด ๋ชจ๋ธ์ ๋ช
์์ ๊ธฐ์ต์ ๋ํ ์ ๊ทผ (datastore์ด๋ผ๋ ๋ช
์์ ์ธ ๋ฐ์ดํฐ) ์ ํตํด long-tail patterns
(์๋ฅผ ๋ค์ด Factual Knowledge
) ์ ๋ํด์ ๋์ฑ ํจ๊ณผ์ ์ธ ๊ฒ์ ๋ฐ๊ฒฌํ๋ค.
2. Nearest Neighbor Language Modeling ๐ง
LM์ ๊ธฐ๋ณธ์ ์ผ๋ก sequence
์ ๋ํ ํ๋ฅ ์ ํ ๋นํ๋ค. ๋ค์ ๋งํด $c_t = (w_1, \cdots , w_{t-1})$ ๋ผ๋ context (sequence)
๊ฐ ์ฃผ์ด์ ธ ์์ ๋, LM (autoregressive ํ)์ $p(w_t|c_t)$๋ฅผ ๊ณ์ฐํด๋ธ๋ค.
kNN-LM์ pre-trained LM
์ ์ด์ฉํ์ฌ nearest-neighbors
๋ฅผ ๊ฒ์(retrieval
) ํ์ฌ augument
ํ๋ ๊ณผ์ ์ ๋ดํฌํ๋ฉฐ, ์ด datastore์๋ key-value
ํํ์ context-target
์๋ค์ด ์ ์ฅ๋์ด ์๋ค. (see Figure 1)
Datastore
$f()$๋ฅผ context $c$๋ฅผ fixed-length vector
๋ก mapping
ํด์ฃผ๋ ํจ์๋ผ๊ณ ๊ฐ์ ํด๋ณด์. ๋ง์ฝ์ i๋ฒ์งธ training data์ธ $(c_i, w_i) \in \mathcal{D}$๊ฐ ์ฃผ์ด์ก์ ๋, ์ฐ๋ฆฌ๋ datastore
์ ๋ค์๊ณผ ๊ฐ์ key-value
์งํฉ์ผ๋ก ์ ์ํ ์ ์๋ค.
Inference
์ด ๋ชจ๋ธ์ input context
$x$๊ฐ ์ฃผ์ด์ก์ ๋, $f(x)$๋ฅผ ํตํด $p_{LM}(y|x)$์ ๋ํ ํ๋ฅ ๋ถํฌ๋ฅผ ๊ณ์ฐํ๊ฒ ๋๋ค. ๋์์ datastore ์ $f(x)$๋ฅผ ์ด์ฉํด query๋ฅผ ๋ณด๋ด๊ฒ ๋๋๋ฐ, distance function
์ธ $d()$๋ฅผ ํตํด k-nearest neighbors
์ ํด๋นํ๋ ์งํฉ $\mathcal{N}$์ ์์ฑํ๋ค. (๋ณธ ๋
ผ๋ฌธ์์๋ distance function
์ $L^2$ distance๋ก ์ ์ํ๋ค.)
๊ทธ ํ์๋ softmax
์ negative distance
๋ฅผ ๋ฃ์์ผ๋ก์จ, ์๋์ ํ๋ฅ ์ ๊ณ์ฐํด๋ผ ์ ์๋ค. (๊ฑฐ๋ฆฌ๊ฐ ๊ฐ๊น์ธ์๋ก ๋์ ํ๋ฅ ์ ์ ํ๋๋ฅผ ๋ณด์ธ๋ค.)
์ด๋ฅผ ์ ํ์ ์ผ๋ก ๊ธฐ์กด LM์ ์ ์ฉํ๊ฒ ๋๋ฉด, $\lambda$ ๋ณ์๋ฅผ ์ด์ฉํด์ ๋ค์๊ณผ ๊ฐ์ด ์ต์ข
probability
๋ฅผ ์ ์ํ ์ ์๋ค.
Implementation
ํ ๊ฐ์ง ๋ฌธ์ ์ ์, Datastore
์ด billion
๊ฐ์ ๋ฐ์ดํฐ๋ฅผ ๋ดํฌํ๊ณ ์์ด computationally intensive
ํ๋ค๋ ๊ฒ์ด๋ค. ์ด๋ฅผ ๊ทน๋ณตํ๊ธฐ ์ํด FAISS
๋ผ๋ open source library
๋ฅผ ์ฌ์ฉํ์ฌ ๊ณ ์ฐจ์ ์์์์ kNN์ ํจ์จ์ ์ผ๋ก ๊ณ์ฐํ๊ฒ ๋๋ค. ์ถ๊ฐ์ ์ผ๋ก, $L^2$ ์ด์ธ์ inner product distance
๋ผ๋ distance function
๋ํ ์กด์ฌํ๋๋ฐ, ์ด ๋ชจ๋ธ์์๋ $L^2$ ๋ฐฉ์์ด ๋ ๋์ ์ฑ๋ฅ์ ๋ณด์๋ค.
Related Cache Models
์ด์ ์ ๋น์ทํ ์ ๊ทผ ๋ฐฉ์์์๋, recent past์ ๋ํ caching
์ ํตํด์ ์ต๊ทผ์ ๋ฐ์ดํฐ์ ๋ํด ๋์ฑ ํจ๊ณผ์ ์ผ๋ก ๊ณ์ฐํ๋ ๋ฐฉ์ ๋ํ ์กด์ฌํ๋ค. ํ์ง๋ง ์ต๊ทผ์ ์ ๋ณด๋ฅผ copyํ ์ ์๋ self-attention
๊ธฐ๋ฒ์ ๊ฐ์ง Transformer
๋ชจ๋ธ์ด ๋ฑ์ฅํ๊ณ ๋์, ์ด ๋ฐฉ์์ ์ธ๊ธฐ๋ฅผ ์๊ฒ ๋๋ฉฐ, ์ป์ ์ ์๋ ์ด์ต ๋ํ ์ค์ด๋ค์๋ค. ๋ณธ ๋
ผ๋ฌธ์์๋ training data
์ ๋ํ ๋ช
์์ ๊ธฐ์ต์ ์ํด ์ค๋ก์ง training data
์ ๋ํด์๋ง caching
ํ๋ ๋ฐฉ์์ ํํ์ฌ ๋น์ทํ cases์ ๋ํ ํจ์จ์ ์ฆ๋์์ผฐ๋ค.
3. Experimental Setup ๐ฅฝ
Data
๋จ์ํ dataset์ ๋ํ ์ค๋ช ์ด๊ธฐ์ ๋ณ๋์ ์ค๋ช ์ ์๋ตํ๋ค.
Model Architecture
kNN-LM ๋ชจ๋ธ์ fixed-size context representations
์ ์์ฑํ๋ ๋ชจ๋ธ์ด๋ผ๋ฉด ๋ชจ๋ ํธํ์ด ๊ฐ๋ฅํ๋ค. ์ด ๋ชจ๋ธ์ ํ์ฌ (๋น์) SOTA๋ฅผ ๊ธฐ๋กํ๋ Decoder-only Transformer
์ ์ฌ์ฉํ๋ค. kNN-LM ๋ชจ๋ธ์ ๋ชจ๋ธ์ ๊ธฐ๋ณธ ๋ชจ๋ธ์ธ LM์ ๋ํ ํ๋ จ์ ์ํํ์ง ์๊ธฐ ๋๋ฌธ์, ๊ธฐ์กด ์ํคํ
์ณ์ ์ต์ ํ ๋ฐฉ์์ ๊ทธ๋๋ก ์ฌ์ฉํ๋ค.
16 layer
, each with 16 self-attention heads
, 1024 dimensional hidden states
, 4096 dimensional feedforward layers
, 247M parameters
๋ฑ์ ์ฌ์ฉํ๋ฉฐ ์ถ๊ฐ ์ ๋ณด๋ ์๋์ ๊ฐ๋ค.
Evaluation
์ด LM๋ค์ negative log-likelihood
๋ฅผ loss function
์ผ๋ก ์ฌ์ฉํ์ผ๋ฉฐ, ๋ชจ๋ธ์ ํ๊ฐ ๊ธฐ์ค์ผ๋ก์จ perplexity
๋ฅผ ์ฌ์ฉํ๋ค. (์ด์ง ์ฒจ์ธํ์๋ฉด, ์ต๊ทผ ๋
ผ๋ฌธ์ธ Mamba์ reject ์์ธ์ผ๋ก perplexity
๊ฐ ํ๊ฐ ๊ธฐ์ค์ด ๋๋ ๊ฒ์ ์ ๋นํ์ง ์๋ค๋ ๊ธ์ ๋ณธ ๊ฒ ๊ฐ์๋ฐ, ์ด๋๋ ๊ธฐ์ค์ด ์กฐ๊ธ ๋ฌ๋๋๋ณด๋ค.)
kNN-LM
Computational Cost
์ถ๊ฐ์ ์ธ Training ์์ด๋ ๋ชจ๋ธ์ ๊ตฌํํ ์ ์์ง๋ง, key-value
ํํ์ datastore
์ ์์ฑํ๊ธฐ ์ํ์ฌ 1 epoch ์ ๋์ ์๊ฐ์ด ์์๋๋ค. ๋ํ key๋ค์ด ์ ์ฅ๋ ํ WIKI-103M
์ ์บ์๋ฅผ CPU์ ์ ์ฉํ๋ ๋ฐ์ 2์๊ฐ์ด ์์๋๋ฉฐ, 1024๊ฐ์ NN์ ๊ตฌํ๋ ๋ฐ์ 25๋ถ ์ ๋๊ฐ ์์๋๋ค. ๋ฌผ๋ก ๋ฐ์ดํฐ์
์ ํฌ๊ธฐ์ ์ ํ์ ์ผ๋ก ๋น๋กํ์ฌ ์๊ฐ์ด ์ฆ๊ฐํ์ง๋ง, ์ด๋ ์ฝ๊ฒ ๋ณ๋ ฌํ๊ฐ ๊ฐ๋ฅํ๋ฉฐ, GPU์ ์ฌ์ฉ์ด ํ์ํ์ง ์๋ค.
4. Experiemtents ๐ฌ
4.1 Using the Training data as the Datastore
๊ธฐ์กด์ SOTA ๋ฐฉ์๊ณผ ๋น๊ตํ์ฌ, ๋ณธ ๋
ผ๋ฌธ์ kNN-LM ์ด ์ผ๋ง๋ ๋์ ์ฑ๋ฅ์ ๋ณด์๋์ง๋ฅผ ์คํํด๋ณด์๋ค. ์ฌ๊ธฐ์ Training data ๊ทธ๋๋ก Datastore
์ ๋์
ํ๋ค. ๊ธฐ์กด์ SOTA์ ๋น๊ตํ์ฌ 18.65 ์์ 16.12๋ก ์๋ก์ด SOTA๋ฅผ ๋ฌ์ฑํ๋ค.
์ถ๊ฐ์ ์ผ๋ก, WIKI
๊ฐ caching์ ์ ๋
์ข์ ๊ฒฝ์ฐ๋ฅผ ๋๋นํ์ฌ BOOKS corpus
๋ฅผ ์ด์ฉํ์ฌ ๊ฐ์ ์คํ์ ๋ฐ๋ณตํด ๋ณด์๋ค. ๊ทธ ๊ฒฐ๊ณผ๋ ์๋์ ๊ฐ๋ค.
4.2 More data without Training
์ด๋ฒ์๋ Training dataset
๊ณผ Datastore
์ ๋ถ๋ฆฌํ์ฌ, ์๋ก ๋ค๋ฅธ ๋ฐ์ดํฐ์
์ ์ด์ฉํ์ ๋์๋ ํจ๊ณผ๊ฐ ์๋์ง๋ฅผ ํ์ธํด๋ณด์๋ค.
์ ๊ทธ๋ฆผ๊ณผ ๊ฐ์ด, WIKI-3B
์ WIKI-100M
dataset์ผ๋ก ์คํํด ๋ณด์๋๋ฐ, ๋น์ฐํ๊ฒ๋ ๋ ํฐ ๋ฐ์ดํฐ์
์ธ WIKI-3B
์ ์ด์ฉํด ํ์ตํ์ ๋ ๋ ๋์ ์ฑ๋ฅ์ ๋ณด์๋ค. ํ์ง๋ง, WIKI-100M
์ผ๋ก ํ์ตํ ๋ค์ WIKI-3B
๋ฅผ datastore
๋ก ํ์ฉํ์ ๋, ๊ทธ ์ฑ๋ฅ์ด ๊ธฐ์กด์ LM์ ๋ฅ๊ฐํ๋ ๊ฒ์ผ๋ก ๋ณด์, kNN-LM ๋ฐฉ์์ด ๋์ฑ ํจ์จ์ ์ด๊ณ ์ ํ๋๊ฐ ๋๋ค๋ ๊ฒ์ ํ์ธํ ์ ์์๋ค.
๋ํ, ์์ ๊ฐ์ด kNN-LM์ datastore
ํฌ๊ธฐ์ ๋ํด์๋ ์คํ์ ํด ๋ณด์๋๋ฐ, 1.6B์ dataset๋ง ์ฌ์ฉํ์ ๋ ์ด๋ฏธ Vanilla LM์ ์ฑ๋ฅ์ ๋ฅ๊ฐํ์๊ณ , 3B์ ๊ฒฝ์ฐ์๋ ๊ทธ๋ฌํ๋ค. ๋ ๋์๊ฐ 3B์ ๊ฒฝ์ฐ์๋ perplexity
์ ๊ฐ์๋๊ฐ saturated
(ํฌํ) ๋์ง ์๋ ๊ฒ์ ๋ณด์, ๋ ํฐ ์ ์ฌ์ฑ์ด ์กด์ฌํ๋ค.
๋ง์ฐฌ๊ฐ์ง๋ก, ๊ฐ datastore
ํฌ๊ธฐ์ ๋ํด์ optimal
ํ $\lambda$๋ฅผ ๊ตฌํด ๋ณด์์ ๋, ์ ๊ทธ๋ฆผ๊ณผ ๊ฐ์ด ๊ฒฐ๊ณผ๊ฐ ๋ํ๋ฌ๋ค.
4.3 Domain Adaptation
Domain Adaptation
์คํ์ ์ํ์ฌ, WIKI-3B
๋ก ํ์ต๋ ๋ชจ๋ธ์ BOOK ์ dataset์ผ๋ก inference ํด๋ณด์๋ค. ๊ทธ ๊ฒฐ๊ณผ๋ ์๋์ ๊ฐ๋ค.
์์ํ๊ฒ datastore
์์ด ์ถ๋ก ํ ๊ฒฐ๊ณผ ๋งค์ฐ ๋ฎ์ ์ฑ๋ฅ์ ๋ณด์์ง๋ง, BOOKS
๋ฅผ datastore
๋ก ํ์ฉํ๊ฒ ๋๋ฉด, perplexity๊ฐ 15 ๊ฐ๊น์ด ๋จ์ด์ง๋ ๊ฒ์ ํ์ธํ ์ ์๋ค. ์ฆ target domain์ ๋ํ datastore
์ด ์๋ค๋ฉด, ์ถฉ๋ถํ ๋ค๋ฅธ domain์ผ๋ก ์ ์ฉ์ด ๊ฐ๋ฅํจ์ ์ ์ ์๋ค.
5. Tuninig Nearest Neighbor Search ๐งช
Key Function
Similarity Search
๋ฅผ ์ํ์ฌ, prior context
๋ฅผ ํจ๊ณผ์ ์ผ๋ก fixed-size representation
์ผ๋ก ๋ณํํ๋ $f()$๋ ์ค์ํ ์์์ด๋ค. ์ด๋ฅผ ์คํํด๋ณด๊ธฐ ์ํด, Transformer ๊ตฌ์กฐ์ ๋ค์ํ ๋ถ๋ถ์ ํ๋ณด๋ก ์คํ์ ์งํํ์๋ค. (๋ชจ๋ Transformer์ ๋ง์ง๋ง layer ๋ถ๋ถ์ด๋ค.)
์คํ์์ ๋ณด๋ ๊ฒ๊ณผ ๊ฐ์ด, FFN input after layer norm
๋ถ๋ถ์ด ๊ฐ์ฅ ๋์ ์ฑ๋ฅ์ ๋ํ๋ด์๊ณ , ์ถ๊ฐ์ ์ผ๋ก ๋ง์ง๋ง ์ง์ ์ layer (second-last) ์๋ ์คํ์ ํด๋ณด์์ผ๋, ๋น์ทํ ๊ฒฝํฅ์ ์ ์์ง๋ง ์ด์ง ๋ฎ์ ์ฑ๋ฅ์ ๋ณด์๋ค.
์ด๋ฅผ ํตํด FFN์ ๋ค์ ํ ํฐ์ ์์ธกํ๋ ๊ฒ์, MHSA๋ representation ์ ๋์ฑ ํจ๊ณผ์ ์ธ ๊ฒ์ด๋ผ๊ณ ์ถ์ธกํด๋ณผ ์ ์๋ค.
Other elements (Number of Neighbors, Interpolation Param, etc)
๊ทธ๋ฆผ์์ ๋ณด๋ค์ํผ, k-NN ์์์ k ๋ฅผ ํ๋์ ์์๋ก, interpolation
๋ณ์๋ฅผ ํ๋์ ์์๋ก ํ์ฌ ์คํ์ ์งํํ๋ค. ๊ทธ ๊ฒฐ๊ณผ k ๊ฐ ๋์ด๋ ์๋ก perplexity
๊ฐ ๊ฐ์ํ๋ฉฐ, ๊ฐ ์ํฉ (In-domain or Domain Adaptation
)์ ๋ง์ถ์ด ์ ์ ํจ $\lambda$ ๊ฐ์ด ํ์ฑ๋๋ ๊ฒ์ ๋ณผ ์ ์๋ค.
๋ํ, Similarity Function
์ precision
์ ๋ํด์๋ ์คํ์ ์งํํ๋๋ฐ, $L^2$ distance์ ๋ํด์ full precision
์ ํตํด ์ฐ์ฐํจ์ผ๋ก์จ perplexity ๊ฐ 16.5์์ 16.06์ผ๋ก ๊ฐ์ํ๋ ์ฑ๋ฅ์ ๋ณด์๋ค.
6. Analysis ๐ง
Qualitative Analysis
$p_{kNN}$ ์ด ์ $p_{LM}$ ๋ณด๋ค ๋์ ์ฑ๋ฅ์ ๊ฐ์ง๋์ง ์ดํดํ๊ธฐ ์ํ์ฌ $p_{kNN}$ ์ด ๋ ๋์ ์ฑ๋ฅ์ ๋ณด์๋ ์์๋ค์ ์ดํด๋ณด๋ฉด ๋ค์๊ณผ ๊ฐ๋ค.
์์์์ ๋ณด๋ค์ํผ, kNN-LM model์ rare patterns
์ ๋ํด์ ๊ฐ์ฅ ๋์ ์ฑ๋ฅ์ ๋ณด์์ผ๋ฉฐ, ์ด๋ ๊ณง factual knowledge
๋ฅผ ๋ปํ๋ค. ํนํ training set์ ์กด์ฌํ๊ฑฐ๋ ๋น์ทํ context
์ ๊ฒฝ์ฐ ๋ ๋์ ์ฑ๋ฅ์ ๋ณด์๋ค.
์ด๋ฅผ ํตํด parameters๋ฅผ ํตํด ์ง์์ implicit
ํ๊ฒ ํ์ตํ๋ ๊ฒ๋ณด๋ค explicit
ํ๊ฒ, ์ฆ Nearest Neighbor
์ ์ฐพ์ ํ์ตํ๋ ๊ณผ์ ์ด ๋์ฑ ํจ์จ์ ์ด๋ผ๊ณ ๋ณผ ์ ์๋ค.
Simple vs Neural Representation
ํ์ง๋ง, ์ด๋ฐ rare patterns
(long-tail phenomena)๋ ๋จ์ํ N-gram model
์์๋ ์ถฉ๋ถํ ์ฑ๋ฅ์ ๊ฐ์ง ์๋ ์๋ค. ๋ฐ๋ผ์ n-gram
๋ชจ๋ธ์ ์ฌ์ฉํ์ฌ ํ์ต์ ์์ผ๋ณด์๋๋ฐ, ๊ฒฐ๊ณผ๋ ์์ ๊ฐ์ด n-gram model
์ ํ์ ํ ๋ฎ์ ์ ํ๋๋ฅผ ๋ํ๋ด์๋ค. ์ด๋ฅผ ๋ฏธ๋ฃจ์ด ๋ณด๋ฉด, kNN-LM์ ๋จ์ํ local context๋ฅผ
ํ์ตํ๋ ๊ฒ ์ด์์ผ๋ก, global context
๋ฅผ ํ์ตํ๋ค๊ณ ์๊ฐํ ์ ์๋ค.
Implicit vs Explicit Memory
๊ทธ๋ ๋ค๋ฉด ๊ณผ์ฐ Explicit Memory ๊ฐ Implicit Memory ๋ณด๋ค ํจ๊ณผ์ ์ผ๊น? ๋ณธ ๋ ผ๋ฌธ์์์ datastore ์ LM ์ด ๋ชจ๋ ์ธ์ฐ๋ ๊ฒ์ด ๊ฐ๋ฅํ ๊น?
์ด๋ฅผ ์คํํด๋ณด๊ธฐ ์ํ์ฌ, Transformer ์ dropout
์์ด ํ์ต์์ผ, datastore
์ ๋ชจ๋ ๋ฐ์ดํฐ๋ฅผ ํ์ตํ๋๋ก ํ๋ค. ๊ทธ ๊ฒฐ๊ณผ ์์ ๊ทธ๋ฆผ ๊ฐ์ด, loss ๊ฐ ์์ 0์ผ๋ก ๋จ์ด์ง ๊ฒ์ ํ์ธํ ์ ์๊ณ , ์ด๋ ๊ณง datastore
์ ๋ชจ๋ ๋ฐ์ดํฐ๋ฅผ ์ ํํ๊ฒ ํ์ตํ๋ค๋ ์๋ฏธ๊ฐ ๋๋ค.
๋ฐ๋ผ์ ์ด๋ ๊ฒ overfitting
๋ LM๊ณผ explicit memory
์ธ datastore
๋ฅผ ๋น๊ตํ๊ธฐ ์ํ์ฌ ๊ฐ๊ฐ์ original LM ์ interpolate
ํ์ฌ perplexity
๋ฅผ ์ธก์ ํ๋๋ฐ, LM์ ๊ฒฐ๊ณผ 0.1 ํฅ์, datastore
์ ๊ฒฐ๊ณผ 1.9 ํฅ์์ผ๋ก explicit memory
๊ฐ ๋ ๋ฐ์ด๋ ์ฑ๋ฅ์ ๋ณด์๋ค.
์ด ์คํ์ ๊ฒฐ๊ณผ๋ก, Transformer LM์ datastore์ ๋ชจ๋ ๋ฐ์ดํฐ๋ฅผ ์ธ์ธ ์ ๋๋ก ๋ฐ์ด๋ ์ฑ๋ฅ์ ๋ณด์์ง๋ง, ์ด ๊ฒฝ์ฐ generalize ์ฑ๋ฅ์ด ๋จ์ด์ง๋ค๋ ๊ฒ์ ์ถ์ธกํด๋ณผ ์ ์๋ค.
8. Conclusion & Future Work ๐ฌ
Related Work Section ์ ์ค์ํ ๋ถ๋ถ์ด ์๋๋ผ๊ณ ํ๋จ๋์ด ์ ์ธํ์์ต๋๋ค.
์ด๋ ๋ฏ kNN-LM model์ ๊ธฐ์กด์ standard LM์ ๋ฅ๊ฐํ๋ ์ฑ๋ฅ์ ๋ณด์๋ค. ์ด๋ฐ ์ ๊ทผ ๋ฐฉ์์ ์์์ NLP task
์ ์ ์ฉ๋ ์ ์๋ค. ์ด ์ ๊ทผ ๋ฐฉ์์ ์ฑ๊ณต์ ์ธํ์ฌ context
๊ฐ์ similarity
๋ฅผ ํ์ตํ๋ ๊ฒ์ด ๋ค์ ํ ํฐ์ ํ์ตํ๋ ๊ฒ๋ณด๋ค ์ฌ์ด task์์ ๋ปํ๊ธฐ๋ ํ๋ค. ์ถํ์๋ ์ด๋ฐ simlarity function
์ ๋ช
์์ ์ผ๋ก ํ์ตํ๊ฑฐ๋, datastore
์ ํฌ๊ธฐ๋ฅผ ์ค์ด๋ ๊ฒ ๋ฑ์ ์ฐ๊ตฌ๊ฐ ํ์ํ ๊ฒ์ด๋ค.