
[Paper Review] Generalization Through Memorization: Nearest Neighbor Language Models

Paper review of the kNN-LM model

arXiv PDF link for kNN-LM

0. Abstract 🎬

kNN-LM extends a pre-trained LM by linearly interpolating it with a k-nearest-neighbors model. The pre-trained LM embeds the input into a latent space, and the k nearest candidates are chosen by the distance between vectors in that space. The neighbors can be drawn from any text collection, including the training data itself.

By applying this approach to a LM trained on Wikitext-103, the paper achieves a new SOTA: without any additional training, perplexity improves by 2.9 points to 15.79. The approach also works effectively with larger training datasets and transfers to other domains.

Qualitatively, the model is especially helpful on rare patterns, factual knowledge in particular. At the same time, the work suggests that learning similarity between sequences may be a more effective approach than the LM's fundamental task of predicting the next token.

1. Introduction ☕️

A language model generally targets the following two tasks:

1. Map the sentence prefix to an n-dimensional vector (more precisely, a fixed-size representation).

2. Use that point in the latent space to predict the next word.

This paper works under the assumption that the first task is easier than the second. For example, given the prefixes "Dickens is the author of" and "Dickens wrote", anyone can tell that the two carry the same meaning even without being able to predict the word that follows. Empirically, applying kNN over prefix embeddings improves performance, which the authors present as strong evidence that LMs are better at the first task.

Training on 100 million tokens and then using a 3-billion-token dataset (documents) as the model's datastore outperformed training the model on all 3 billion tokens. This suggests a new direction for LMs over large datasets. Similarly, simply inserting data from other domains into the datastore was enough to perform well across multiple domains.

image

Finally, the authors find that access to explicit memory (the datastore) makes the model particularly effective on long-tail patterns, such as factual knowledge.

2. Nearest Neighbor Language Modeling 🧐

LM์€ ๊ธฐ๋ณธ์ ์œผ๋กœ sequence์— ๋Œ€ํ•œ ํ™•๋ฅ ์„ ํ• ๋‹นํ•œ๋‹ค. ๋‹ค์‹œ ๋งํ•ด $c_t = (w_1, \cdots , w_{t-1})$ ๋ผ๋Š” context (sequence)๊ฐ€ ์ฃผ์–ด์ ธ ์žˆ์„ ๋•Œ, LM (autoregressive ํ•œ)์€ $p(w_t|c_t)$๋ฅผ ๊ณ„์‚ฐํ•ด๋‚ธ๋‹ค.

kNN-LM์€ pre-trained LM์„ ์ด์šฉํ•˜์—ฌ nearest-neighbors๋ฅผ ๊ฒ€์ƒ‰(retrieval) ํ•˜์—ฌ augument ํ•˜๋Š” ๊ณผ์ •์„ ๋‚ดํฌํ•˜๋ฉฐ, ์ด datastore์—๋Š” key-value ํ˜•ํƒœ์˜ context-target ์Œ๋“ค์ด ์ €์žฅ๋˜์–ด ์žˆ๋‹ค. (see Figure 1)

Datastore

Let $f(\cdot)$ be a function that maps a context $c$ to a fixed-length vector. Given the $i$-th training example $(c_i, w_i) \in \mathcal{D}$, we can define the datastore as the following set of key-value pairs:

\[(\mathcal{K}, \mathcal{V}) = \{(f(c_i), w_i) \mid (c_i, w_i) \in \mathcal{D}\}\]
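
As a minimal sketch of this construction, the snippet below builds $(\mathcal{K}, \mathcal{V})$ over a toy corpus; the encoder `f` here is a self-contained stand-in (a mean of random embeddings), not the paper's Transformer:

```python
import numpy as np

rng = np.random.default_rng(0)
EMBED = rng.normal(size=(50_000, 1024)).astype("float32")  # fake token embeddings

def f(context_ids):
    """Stand-in for f(c): map a variable-length context to a fixed-length vector."""
    return EMBED[context_ids].mean(axis=0)

# Build (K, V): one key-value pair per token position in the corpus.
corpus = [3, 17, 42, 7, 99, 3, 17, 55]  # toy token-id stream
keys, values = [], []
for t in range(1, len(corpus)):
    keys.append(f(corpus[:t]))   # key  f(c_i): representation of the prefix
    values.append(corpus[t])     # value w_i : the token that actually followed
K = np.stack(keys)               # shape (|D|, 1024)
V = np.array(values)             # shape (|D|,)
```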

Inference

์ด ๋ชจ๋ธ์€ input context $x$๊ฐ€ ์ฃผ์–ด์กŒ์„ ๋•Œ, $f(x)$๋ฅผ ํ†ตํ•ด $p_{LM}(y|x)$์— ๋Œ€ํ•œ ํ™•๋ฅ  ๋ถ„ํฌ๋ฅผ ๊ณ„์‚ฐํ•˜๊ฒŒ ๋œ๋‹ค. ๋™์‹œ์— datastore ์— $f(x)$๋ฅผ ์ด์šฉํ•ด query๋ฅผ ๋ณด๋‚ด๊ฒŒ ๋˜๋Š”๋ฐ, distance function์ธ $d()$๋ฅผ ํ†ตํ•ด k-nearest neighbors์— ํ•ด๋‹นํ•˜๋Š” ์ง‘ํ•ฉ $\mathcal{N}$์„ ์ƒ์„ฑํ•œ๋‹ค. (๋ณธ ๋…ผ๋ฌธ์—์„œ๋Š” distance function์„ $L^2$ distance๋กœ ์ •์˜ํ–ˆ๋‹ค.)

A softmax over the negative distances then yields the following distribution over the retrieved targets (closer neighbors receive higher probability):

\[p_{kNN}(y|x) \propto \sum_{(k_i, v_i) \in \mathcal{N}} \mathbb{1}_{y=v_i} \exp(-d(k_i, f(x)))\]

Interpolating this linearly with the base LM via a parameter $\lambda$ defines the final probability:

\[p(y|x) = \lambda \, p_{kNN}(y|x) + (1 - \lambda) \, p_{LM}(y|x)\]
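
Both formulas fit in a few lines of numpy. This is a sketch under my own naming; $\lambda = 0.25$ is only an illustrative default (roughly the paper's tuned in-domain value):

```python
import numpy as np

def knn_lm_probs(dists, neighbor_targets, p_lm, lam=0.25):
    """Interpolate the kNN distribution with the base LM distribution.

    dists:            (k,) distances d(k_i, f(x)) to the retrieved keys
    neighbor_targets: (k,) token ids v_i stored with those keys
    p_lm:             (V,) base LM distribution p_LM(y|x)
    """
    # Softmax over negative distances: closer neighbors get more mass.
    w = np.exp(-(dists - dists.min()))  # shift by the min distance for stability
    w /= w.sum()
    # Aggregate neighbor mass per vocabulary item (the indicator sum above).
    p_knn = np.zeros_like(p_lm)
    np.add.at(p_knn, neighbor_targets, w)
    return lam * p_knn + (1.0 - lam) * p_lm

# Toy usage: three neighbors, two of which stored token 5.
p = knn_lm_probs(np.array([1.0, 2.0, 1.5]),
                 np.array([5, 5, 9]),
                 np.full(10, 0.1))
```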

Implementation

One practical problem is that the datastore can contain billions of entries, making search computationally intensive. To overcome this, the authors use FAISS, an open-source library for efficient kNN search in high-dimensional spaces. As for the distance function, inner product is an alternative to $L^2$, but $L^2$ performed better for this model.
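
A sketch of the retrieval step with FAISS, using a flat exact-$L^2$ index for clarity (the paper scales to billions of keys with a quantized FAISS index; all arrays below are random stand-ins):

```python
import numpy as np
import faiss  # pip install faiss-cpu

d = 1024                                                 # key dimensionality
rng = np.random.default_rng(0)
keys = rng.normal(size=(100_000, d)).astype("float32")   # stand-ins for f(c_i)
values = rng.integers(0, 50_000, size=100_000)           # stand-ins for w_i

index = faiss.IndexFlatL2(d)   # exact L2 search
index.add(keys)

query = rng.normal(size=(1, d)).astype("float32")        # f(x) for a test context
dists, ids = index.search(query, 1024)                   # 1024 nearest neighbors
neighbor_targets = values[ids[0]]                        # their stored targets v_i
```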

Earlier work in a similar vein cached hidden states from the recent past to predict recent tokens more effectively. But after the arrival of Transformers, whose self-attention can already copy from the recent context, that approach lost popularity and its gains shrank. This paper instead caches only the training data, giving the model explicit memory of it and boosting performance on similar cases.

3. Experimental Setup 🥽

Data

This subsection simply describes the datasets, so I omit the details.

image

Model Architecture

kNN-LM is compatible with any model that produces fixed-size context representations. Here the authors use the decoder-only Transformer that held the SOTA at the time. Because kNN-LM involves no training of the underlying base LM, the existing architecture and optimization procedure were used unchanged.

The model has 16 layers, each with 16 self-attention heads, 1024-dimensional hidden states, 4096-dimensional feedforward layers, and 247M parameters; additional details are below.

image

Evaluation

These LMs are trained with a negative log-likelihood loss, and perplexity is used as the evaluation metric. (A small aside: I recall reading that one criticism around the recent Mamba paper's rejection was that perplexity is not a valid evaluation criterion; the standards were apparently different back then.)
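
For reference, perplexity is the exponentiated average negative log-likelihood over the evaluation tokens:

\[\text{PPL} = \exp\left(-\frac{1}{T}\sum_{t=1}^{T}\log p(w_t \mid c_t)\right)\]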

kNN-LM

image image

Computational Cost

Although the model needs no additional training, generating the key-value datastore takes about as long as one training epoch (a single forward pass over the data). Once the keys are saved, building the cache for WIKI-103M on a CPU takes about 2 hours, and retrieving the 1024 nearest neighbors takes about 25 minutes. These costs grow linearly with the size of the dataset, but they are easy to parallelize and require no GPU.

4. Experiments 🔬

4.1 Using the Training Data as the Datastore

image

The first experiment measures how much kNN-LM improves over the previous SOTA, with the training data used directly as the datastore. kNN-LM improves perplexity from 18.65 to 16.12, a new state of the art.

Additionally, to guard against the possibility that WIKI text is unusually amenable to caching, the same experiment was repeated on the BOOKS corpus. The results are below.

image

4.2 More Data Without Training

Next, the training set and the datastore are decoupled, to check whether the approach still helps when the two are different datasets.

image

As the figure above shows, models were trained on WIKI-3B and on WIKI-100M; unsurprisingly, the model trained on the larger WIKI-3B performs better. However, the model trained on WIKI-100M with WIKI-3B as its datastore outperforms the LM trained on all of WIKI-3B, showing that the kNN-LM approach is both more efficient and more accurate.

image

The datastore size for kNN-LM was also varied: with a datastore of only 1.6B tokens the model already surpasses the vanilla LM, and it does so at 3B as well. Furthermore, the drop in perplexity has not yet saturated at 3B, suggesting even greater potential.

Likewise, the optimal $\lambda$ was computed for each datastore size, with results as shown in the figure above.

4.3 Domain Adaptation

For the domain adaptation experiment, a model trained on WIKI-3B is run for inference on the BOOKS dataset. The results are below.

image

Running inference without any datastore performs very poorly, but with BOOKS as the datastore, perplexity falls by nearly 15 points. In other words, given a datastore for the target domain, the model can readily adapt to another domain.

5. Tuning Nearest Neighbor Search 🧪

Key Function

For similarity search, the function $f(\cdot)$ that converts the prior context into an effective fixed-size representation is a crucial component. To study it, the authors experiment with representations from various parts of the Transformer (all taken from its final layer).

image

As the experiments show, the FFN input after layer norm performed best. The second-to-last layer was also tried; it showed a similar trend but slightly lower scores.

This suggests that the FFN specializes in predicting the next token, while MHSA is more effective at building the context representation.
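
To make "FFN input after layer norm" concrete, here is a toy pre-LN block with a forward hook that captures that activation as the key; the module layout and sizes are my own stand-ins, not the paper's code:

```python
import torch
from torch import nn

class Block(nn.Module):
    """Toy pre-LN Transformer block (a stand-in, not the paper's model)."""
    def __init__(self, d=1024):
        super().__init__()
        self.attn_norm = nn.LayerNorm(d)
        self.attn = nn.MultiheadAttention(d, num_heads=16, batch_first=True)
        self.ffn_norm = nn.LayerNorm(d)
        self.ffn = nn.Sequential(nn.Linear(d, 4096), nn.ReLU(), nn.Linear(4096, d))

    def forward(self, x):
        h = self.attn_norm(x)
        a, _ = self.attn(h, h, h)
        x = x + a                          # residual after self-attention
        return x + self.ffn(self.ffn_norm(x))

block, grabbed = Block(), {}
# The input to self.ffn is exactly the "FFN input after layer norm".
block.ffn.register_forward_hook(lambda mod, inp, out: grabbed.update(key=inp[0]))

x = torch.randn(1, 8, 1024)                # (batch, seq, hidden)
block(x)
f_x = grabbed["key"][:, -1, :]             # f(x): key vector for the last position
print(f_x.shape)                           # torch.Size([1, 1024])
```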

Other Elements (Number of Neighbors, Interpolation Parameter, etc.)

image

As the figure shows, the number of neighbors $k$ and the interpolation parameter were each varied. Perplexity decreases as $k$ grows, and an appropriate $\lambda$ emerges for each setting (in-domain vs. domain adaptation).

The precision of the similarity function was also examined: recomputing the $L^2$ distances at full precision reduces perplexity from 16.5 to 16.06.
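
A sketch of what that full-precision step could look like: take candidate ids from the (quantized, approximate) index, then re-rank them with exact squared-$L^2$ distances against the stored full-precision keys. All names and arrays here are illustrative:

```python
import numpy as np

rng = np.random.default_rng(0)
keys = rng.normal(size=(100_000, 1024)).astype("float32")   # full-precision keys
query = rng.normal(size=(1024,)).astype("float32")          # f(x)
approx_ids = rng.choice(100_000, size=1024, replace=False)  # ids from approximate search

# Exact squared-L2 distances for the candidates only, then re-rank.
exact = ((keys[approx_ids] - query) ** 2).sum(axis=1)
order = np.argsort(exact)
reranked_ids = approx_ids[order]
exact_dists = exact[order]
```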

6. Analysis 🧐

Qualitative Analysis

To understand why $p_{kNN}$ outperforms $p_{LM}$, the authors examine examples where $p_{kNN}$ did better:

image

As the examples show, kNN-LM performs best on rare patterns, which largely amount to factual knowledge. It is especially strong when the same or a similar context exists in the training set.

This suggests that finding nearest neighbors in explicit memory is more efficient than learning knowledge implicitly in the parameters.

Simple vs Neural Representation

figure 7, 8

However, such rare patterns (long-tail phenomena) might conceivably be handled just as well by a simple n-gram model. When an n-gram model was trained, though, it showed markedly lower accuracy, as the figures above show. This implies that kNN-LM learns more than local context: it captures global context.

Implicit vs Explicit Memory

So is explicit memory really more effective than implicit memory? Could the LM simply memorize the entire datastore in its parameters?

To test this, a Transformer was trained without dropout so that it could memorize all the data in the datastore. As the figure shows, the training loss drops all the way to zero, which means the model has memorized the datastore exactly.

๋”ฐ๋ผ์„œ ์ด๋ ‡๊ฒŒ overfitting ๋œ LM๊ณผ explicit memory์ธ datastore๋ฅผ ๋น„๊ตํ•˜๊ธฐ ์œ„ํ•˜์—ฌ ๊ฐ๊ฐ์„ original LM ์— interpolate ํ•˜์—ฌ perplexity๋ฅผ ์ธก์ •ํ–ˆ๋Š”๋ฐ, LM์˜ ๊ฒฐ๊ณผ 0.1 ํ–ฅ์ƒ, datastore ์˜ ๊ฒฐ๊ณผ 1.9 ํ–ฅ์ƒ์œผ๋กœ explicit memory๊ฐ€ ๋” ๋›ฐ์–ด๋‚œ ์„ฑ๋Šฅ์„ ๋ณด์˜€๋‹ค.

The conclusion from this experiment: the Transformer LM is expressive enough to memorize the entire datastore, but doing so evidently costs it generalization.

8. Conclusion & Future Work 🎬

The Related Work section is omitted here, as I judged it not essential.

In sum, the kNN-LM model outperforms the standard LM, and this kind of approach can be applied to arbitrary NLP tasks. Its success also implies that learning similarity between contexts is an easier task than predicting the next token. Future work includes learning the similarity function explicitly and reducing the size of the datastore.
