Hierarchical processing, patch merging, and relative position bias. Prerequisite: Notebook 5 introduces W-MSA, SW-MSA, and the cyclic shift. Uses the same 32×32 toy image with 4×4 patches (8×8 patch grid, 64 patches, $d=6$, window size $W=4$).
Summary: SWIN organizes its transformer blocks into stages, each running at a fixed spatial resolution. A W-MSA/SW-MSA block pair is the fundamental repeating unit; a stage is a sequence of such pairs. Between stages, patch merging layers halve the spatial resolution and double the channel depth, building a four-scale feature hierarchy that supports both image classification and dense prediction.
Motivation. The ViT of Notebook 4 processes all patch tokens at the same spatial resolution through every transformer block — convenient for image classification, but awkward for dense prediction tasks (object detection, semantic segmentation) that rely on features at multiple scales. SWIN addresses this through a hierarchical design built around two structural ideas: block pairs and stages.
Block pairs. One W-MSA pass followed by one SW-MSA pass forms a block pair. The two always appear together: the regular-window pass (W-MSA) establishes a fixed partition; the shifted-window pass (SW-MSA) then bridges exactly the boundaries the regular pass created. A single block pair therefore gives an interior patch a receptive field spanning $2W$ consecutive patches along each axis, twice the $W$-patch span of one windowed pass alone. This is the full benefit of the alternating scheme from Sections 4 and 5 of Notebook 5. (The SWIN paper uses "block" to mean a single W-MSA or SW-MSA transformer layer; this notebook uses "block pair" to emphasize that the two always come as a unit.)
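A small 1D simulation makes the block-pair receptive field concrete. This is a minimal sketch under stated assumptions: a hypothetical 16-patch row, $W=4$, a shift of $W/2$, and the edge windows treated as separate partial windows (which is what the masked cyclic shift achieves). Because each 2D window is the product of a row interval and a column interval, the same analysis applies independently along each axis. The helper names (`window_mask`, `block_pair_reach`, `span`) are illustrative.

```python
import numpy as np

def window_mask(n: int, W: int, shift: int) -> np.ndarray:
    """M[i, j] is True when tokens i and j share a window in one windowed pass.

    Window boundaries are offset by `shift` positions from the regular grid;
    the partial windows at the two ends play the role of the masked
    cyclic-shift windows, so tokens at opposite ends never interact.
    """
    window_id = (np.arange(n) + shift) // W
    return window_id[:, None] == window_id[None, :]

def block_pair_reach(n: int, W: int) -> np.ndarray:
    """reach[i, j] is True when token j's information can reach token i after
    one W-MSA pass followed by one SW-MSA pass."""
    w_msa = window_mask(n, W, shift=0).astype(int)
    sw_msa = window_mask(n, W, shift=W // 2).astype(int)
    # j -> k within a W-MSA window, then k -> i within an SW-MSA window.
    return (sw_msa @ w_msa) > 0

def span(row: np.ndarray) -> int:
    """Number of consecutive positions covered by the True entries of a row."""
    idx = np.flatnonzero(row)
    return int(idx[-1] - idx[0] + 1)

n, W = 16, 4                      # hypothetical 16-patch row, window size 4
reach = block_pair_reach(n, W)
spans = [span(row) for row in reach]

# prints: token 4 sees [0, 1, 2, 3, 4, 5, 6, 7]
print("token 4 sees", np.flatnonzero(reach[4]).tolist())
# prints: widest span: 8 patches (2W = 8)
print(f"widest span: {max(spans)} patches (2W = {2 * W})")
```

Edge tokens such as token 0 see only $W$ patches, because their shifted window is a masked partial window that lies entirely inside one regular window.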
Stages. A stage is a sequence of one or more block pairs that all run on the same patch grid, at the same spatial resolution and channel depth. Each block within a stage still has its own learned weights (attention projections, MLP, and relative position bias table); what the blocks share is the grid they operate on. The toy example in this notebook, an 8×8 patch grid with $d=6$, represents a single stage. Stages are linked by patch merging layers (Section 2), each of which halves the spatial resolution and doubles the channel depth, handing a coarser, richer representation to the next stage.
SWIN-T's four stages. A 224×224 image with 4×4-pixel patches enters as a 56×56 patch grid with $d=96$; the table below shows how the four stages are structured. All stages use window size $W=7$, fixed throughout the network. Stage 3 does the most work, using three block pairs to process the 14×14 grid at the intermediate scale where fine spatial detail and semantic depth are both important. The other three stages each use one block pair. Other SWIN variants (Section 4) deepen the network by increasing $d$ or adding more block pairs to Stage 3, while keeping the four-stage structure.
| Stage | Patch grid | Channels $d$ | Window $W$ | Block pairs |
|---|---|---|---|---|
| 1 | 56×56 | 96 | 7 | 1 |
| 2 | 28×28 | 192 | 7 | 1 |
| 3 | 14×14 | 384 | 7 | 3 |
| 4 | 7×7 | 768 | 7 | 1 |
SWIN-T on a 224×224 image ($P=4$, initial patch grid 56×56). Each row is one stage; patch merging precedes Stages 2–4.
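The table's arithmetic can be reproduced in a few lines. The sketch below (plain Python; the function name `swin_t_stages` is hypothetical) derives each stage's grid and width from the 56×56, $d=96$ starting point, and adds two derived quantities the table does not list: the number of $7\times7$ windows per stage and the number of query-key scores computed per head in one windowed attention layer.

```python
def swin_t_stages(grid: int = 56, d: int = 96, W: int = 7, pairs=(1, 1, 3, 1)):
    """Per-stage shapes for SWIN-T; patch merging halves grid and doubles d."""
    rows = []
    for stage, n_pairs in enumerate(pairs, start=1):
        n_windows = (grid // W) ** 2          # non-overlapping W x W windows
        scores = n_windows * (W * W) ** 2     # query-key scores per head per layer
        rows.append((stage, grid, d, n_windows, n_pairs, scores))
        grid, d = grid // 2, d * 2            # patch merging before the next stage
    return rows

for stage, grid, d, n_win, n_pairs, scores in swin_t_stages():
    print(f"Stage {stage}: {grid}x{grid} grid, d={d}, {n_win} windows, "
          f"{n_pairs} pair(s), {scores} scores per layer")
```

At Stage 1, windowed attention computes 64 × 49² = 153,664 scores per head per layer, versus 3136² = 9,834,496 for global attention over the 56×56 grid: a 64× reduction. At Stage 4 a single $7\times7$ window covers the whole grid, so attention there is effectively global.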
Relative positional bias. SWIN adds no positional encoding vector to the token embeddings — there is no sinusoidal or learned absolute PE applied at the input, as there was in Notebook 4. Spatial structure enters instead through the relative position bias added directly to the attention scores inside each attention computation at every block. Section 3 gives the full treatment.
Outputs. The four stages produce four feature maps at decreasing resolution: 56×56 (Stage 1) through 7×7 (Stage 4), with channel depth increasing at each step. For classification: Stage 4's 7×7 output is globally average-pooled to a single 768-dimensional vector, which a linear head maps to class logits — no $[\mathrm{CLS}]$ token is needed or used. For dense prediction (detection, segmentation): the global pool is skipped entirely; all four feature maps form a multi-scale pyramid passed to a task-specific head such as FPN or UPerNet. This dual-use design — a single backbone that handles both classification and dense tasks without architectural modification — is a key advantage of the hierarchical structure over plain ViT.
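The two output paths can be sketched with NumPy. Everything here is a stand-in: random arrays play the role of the four stage outputs, and `W_head` is a hypothetical randomly initialized linear head for 1000 classes (the class count is an assumption for illustration). The sketch also omits the LayerNorm that the SWIN reference implementation applies to the Stage 4 output before pooling.

```python
import numpy as np

rng = np.random.default_rng(0)

# Stand-ins for the four stage outputs of SWIN-T (grid x grid x channels).
pyramid = [rng.standard_normal((g, g, c))
           for g, c in [(56, 96), (28, 192), (14, 384), (7, 768)]]

# Classification: global average pool over Stage 4's 7x7 grid, then a linear head.
num_classes = 1000                                   # assumed class count
W_head = rng.standard_normal((768, num_classes)) * 0.02
b_head = np.zeros(num_classes)
pooled = pyramid[-1].mean(axis=(0, 1))               # (768,), no [CLS] token needed
logits = pooled @ W_head + b_head                    # (1000,)

# Dense prediction: skip the pool and hand all four maps to a task head.
print([f.shape for f in pyramid])
print(pooled.shape, logits.shape)
```

The backbone code is identical in both cases; only what is done with `pyramid` differs.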
Summary: Between SWIN stages, adjacent $2\times2$ groups of patch tokens are concatenated and linearly projected to half their combined dimension. Spatial resolution halves and channel depth doubles at each merge, linking each stage to the next at a coarser, richer scale.
Patch merging mechanics. At each patch merging step, the $H_{\mathrm{grid}}\times W_{\mathrm{grid}}$ patch grid is divided into non-overlapping $2\times2$ groups (here $H_{\mathrm{grid}}$ and $W_{\mathrm{grid}}$ are patch counts, not to be confused with the window size $W$). The four $d$-dimensional tokens in each group are concatenated into a $4d$-dimensional vector, which a linear layer then projects to $2d$ dimensions. Writing $p_{i,j}$ for the token at row $i$, column $j$ of the finer grid, the merged token at position $(r, c)$ of the coarser grid is $$\text{merged}_{r,c} = \mathrm{concat}(p_{2r,2c},\, p_{2r,2c+1},\, p_{2r+1,2c},\, p_{2r+1,2c+1})\, W_{\text{merge}}, \qquad W_{\text{merge}} \in \mathbb{R}^{4d \times 2d}.$$ The $H_{\mathrm{grid}}\times W_{\mathrm{grid}}$ grid shrinks to $\tfrac{H_{\mathrm{grid}}}{2}\times\tfrac{W_{\mathrm{grid}}}{2}$: $N$ tokens become $\tfrac{N}{4}$ merged tokens, each carrying twice the channel depth. Like the $W\times W$ windowing of W-MSA, patch merging operates on a spatially local subset of the patch grid, but where windowing groups patches for attention, merging groups them for spatial downsampling. In the toy example, the 8×8 patch grid with $d=6$ merges to a 4×4 patch grid with $d=12$.
Merged token matrix: rows are the 16 merged tokens from the 8×8→4×4 patch grid merge, labeled by merged-grid position $(r,c)$; columns are the 12 output dims. Each row is built from the four $d=6$ patch embeddings of one $2\times2$ group, concatenated to $4d=24$ dims and then projected to the $2d=12$ output dims shown here.
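A minimal NumPy sketch of the merge on the toy grid, following the concatenation order of the $\text{merged}$ equation. `patch_merge` is a hypothetical helper and `W_merge` is random rather than learned. The SWIN reference implementation additionally applies LayerNorm to the $4d$ vector before the projection and concatenates the four tokens in a different order (equivalent to permuting the rows of $W_{\text{merge}}$); both details are omitted here.

```python
import numpy as np

def patch_merge(x: np.ndarray, W_merge: np.ndarray) -> np.ndarray:
    """Merge 2x2 groups of an (H, W, d) token grid into an (H/2, W/2, 2d) grid.

    Concatenation order per group: (2r,2c), (2r,2c+1), (2r+1,2c), (2r+1,2c+1).
    """
    H, Wg, d = x.shape
    assert H % 2 == 0 and Wg % 2 == 0, "grid sides must be even"
    # (H, W, d) -> (H/2, 2, W/2, 2, d): axis 1 = row within group, axis 3 = col.
    groups = x.reshape(H // 2, 2, Wg // 2, 2, d)
    # Reorder to (H/2, W/2, row-in-group, col-in-group, d), then flatten to 4d.
    concat = groups.transpose(0, 2, 1, 3, 4).reshape(H // 2, Wg // 2, 4 * d)
    return concat @ W_merge                      # (H/2, W/2, 2d)

rng = np.random.default_rng(0)
d = 6
tokens = rng.standard_normal((8, 8, d))          # toy 8x8 grid, d = 6
W_merge = rng.standard_normal((4 * d, 2 * d))    # random stand-in for learned weights
merged = patch_merge(tokens, W_merge)
print(merged.shape)                              # (4, 4, 12)
```

The reshape/transpose pair gathers each $2\times2$ group without an explicit loop; a double loop over $(r, c)$ would give the same result more slowly.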
Summary: Rather than adding a positional encoding vector to the token embeddings, SWIN adds a small learnable bias — indexed by the relative displacement $(\Delta r,\,\Delta c)$ between query and key — directly to the attention scores inside each window. Like RoPE (Notebook 2), this encodes relative rather than absolute position and acts on the attention logits rather than the residual stream; unlike RoPE, it uses a compact learned lookup table made tractable by the fixed window size.
Three approaches to positional encoding. Notebook 2 introduced two strategies. Sinusoidal PE (Section 4) adds an absolute position vector to the token embeddings before the first block — each token carries "I am at position $n$," and that signal propagates through the residual stream for the rest of the forward pass. RoPE (Section 5, now the dominant approach in language models) takes a different tack: rather than modifying embeddings, it rotates the Q and K vectors inside each attention head so that their dot product naturally depends on the relative distance between tokens — no embedding modification, no residual stream effect. Notebook 4's ViT also used sinusoidal PE added to embeddings.
SWIN takes a third approach. No positional vector is ever added to the token embeddings. Instead, a small learnable bias is added directly to the attention scores at every block and every stage, indexed by the offset $(\Delta r, \Delta c)$ between a query patch and a key patch, not by either patch's absolute grid position. Within a block, the same bias table is shared by all windows (each attention head and each block learns its own table), and it is applied fresh at every attention computation rather than being baked into the token embeddings once.
Comparison to RoPE. Of the three approaches, SWIN's bias is philosophically closest to RoPE: both encode relative position, and both inject that information into the attention scores rather than the token embeddings or the residual stream. The key differences are mechanism and scope. RoPE uses a parameter-free geometric construction — rotating Q and K in a way that makes their dot product depend on relative position — and works for any token-pair distance, applied globally across all pairs. SWIN instead uses a compact learned lookup table: because all windows share the same fixed size $W\times W$, the set of possible offsets $(\Delta r, \Delta c)$ with $\Delta r, \Delta c \in [-(W-1),\,W-1]$ is finite, making a $(2W-1)\times(2W-1)$ table both sufficient and efficient to learn.
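The finiteness argument is easy to check by enumerating every query-key pair in one window. The sketch below (plain Python; `distinct_offsets` is a hypothetical helper) confirms that a $W\times W$ window produces exactly $(2W-1)^2$ distinct offsets, far fewer than the $W^4$ query-key pairs that share them.

```python
from itertools import product

def distinct_offsets(W: int) -> set:
    """All (dr, dc) offsets between query and key positions inside a W x W window."""
    cells = list(product(range(W), repeat=2))            # local (row, col) positions
    return {(rp - rq, cp - cq) for (rp, cp), (rq, cq) in product(cells, repeat=2)}

for W in (4, 7):
    offs = distinct_offsets(W)
    # columns: W, distinct offsets, (2W-1)^2, number of query-key pairs W^4
    print(W, len(offs), (2 * W - 1) ** 2, W ** 4)
# 4 49 49 256
# 7 169 169 2401
```

For SWIN-T's $W=7$, 2401 query-key pairs per window draw on only 169 table entries per head.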
The bias table. SWIN learns a bias table with entries $b_{ij}$ indexed directly by displacement, $i = \Delta r$ and $j = \Delta c$, with $i, j \in \{-(W-1), \ldots, W-1\}$. For $W=4$ this is a $7\times7$ table of 49 parameters per attention head. The table entries assemble into a $W^2\times W^2$ bias matrix: for tokens $p$ and $q$ at local grid positions $(r_p, c_p)$ and $(r_q, c_q)$ within a window, the $(p,q)$ entry is $$B_{pq} = b_{\,r_p - r_q,\; c_p - c_q}.$$ The modified attention computation becomes $$\mathrm{Attention}(Q,K,V) = \mathrm{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}} + B\right)\!V,$$ where the same $B$ is used for every window within a block. Importantly, this bias is added to the attention scores, not to the token embeddings, so it enters the computation at a different point than sinusoidal PE.
Bias table $b_{ij}$ (7×7). Rows: $i = \Delta r \in [-3,3]$; cols: $j = \Delta c \in [-3,3]$. Peak at $(0,0)$; decays with Manhattan distance — a typical learned pattern (here hand-designed for illustration).
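Assembled bias matrix $B$ (16×16) for one window ($W=4$). Entry $B_{pq} = b_{\Delta r,\,\Delta c}$ where $\Delta r = r_p - r_q$, $\Delta c = c_p - c_q$. Rows/columns labeled by local patch position (row, col). With patches ordered row by row, $B$ splits into a 4×4 grid of 4×4 blocks: each block pairs one query row $r_p$ with one key row $r_q$, so $\Delta r$ is constant within a block and only $\Delta c$ varies, and blocks with the same $\Delta r$ are identical.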
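The following NumPy sketch assembles $B$ from a bias table and adds it to the scaled attention scores of one window. Assumptions: a single head; a hypothetical hand-designed table $b_{ij} = -0.5\,(|i| + |j|)$, which peaks at $(0,0)$ and decays with Manhattan distance but is not the figure's exact values; random $Q$, $K$, $V$; and no SW-MSA mask. Offsets are shifted by $W-1$ to become non-negative array indices, similar to how the SWIN reference implementation indexes its table.

```python
import numpy as np

def relative_bias_matrix(table: np.ndarray, W: int) -> np.ndarray:
    """Assemble the W^2 x W^2 matrix B[p, q] = b[r_p - r_q, c_p - c_q].

    `table` has shape (2W-1, 2W-1); table[i + W-1, j + W-1] stores b_ij,
    so offsets in [-(W-1), W-1] are shifted to non-negative indices.
    """
    rows, cols = np.divmod(np.arange(W * W), W)          # row-major local positions
    dr = rows[:, None] - rows[None, :]                   # (W^2, W^2) row offsets
    dc = cols[:, None] - cols[None, :]                   # (W^2, W^2) col offsets
    return table[dr + W - 1, dc + W - 1]

def window_attention(Q, K, V, B):
    """softmax(Q K^T / sqrt(d_k) + B) V for the tokens of one window."""
    scores = Q @ K.T / np.sqrt(Q.shape[-1]) + B
    scores -= scores.max(axis=-1, keepdims=True)         # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ V

W, d_k = 4, 6
offsets = np.arange(-(W - 1), W)                          # -3 .. 3
# Hypothetical hand-designed table: peak at (0, 0), decays with Manhattan distance.
table = -0.5 * (np.abs(offsets)[:, None] + np.abs(offsets)[None, :])
B = relative_bias_matrix(table, W)

rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((W * W, d_k)) for _ in range(3))
out = window_attention(Q, K, V, B)
print(B.shape, out.shape)                                 # (16, 16) (16, 6)
```

In a trained model, `table` would be a learned parameter, one per head in each block; the gather in `relative_bias_matrix` stays the same.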
The full SWIN architecture assembles patch embedding, W-MSA/SW-MSA block pairs, patch merging, and relative position bias into a four-stage hierarchical backbone for image classification, object detection, and semantic segmentation.
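| Step | Operation | Output (grid × channels) |
|---|---|---|
| Input | Image | 224×224 × 3 |
| Patch embed | 4×4-px patches, linear projection | 56×56 × 96 |
| Stage 1 | W-MSA + SW-MSA + rel. pos. bias (2 blocks) | 56×56 × 96 |
| Patch merge | $2\times2$ concat, linear projection | 28×28 × 192 |
| Stage 2 | W-MSA + SW-MSA + rel. pos. bias (2 blocks) | 28×28 × 192 |
| Patch merge | $2\times2$ concat, linear projection | 14×14 × 384 |
| Stage 3 | (W-MSA + SW-MSA) ×3 + rel. pos. bias (6 blocks) | 14×14 × 384 |
| Patch merge | $2\times2$ concat, linear projection | 7×7 × 768 |
| Stage 4 | W-MSA + SW-MSA + rel. pos. bias (2 blocks) | 7×7 × 768 |
SWIN-T on an example 224×224×3 input. "Blocks" counts single W-MSA or SW-MSA layers, as in the SWIN paper, so 2 blocks = 1 block pair.
Classification: Stage 4 output → global average pool → linear head → class logits.
Detection / Segmentation: 4-scale feature pyramid (all four stage outputs) → FPN / UPerNet.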
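To tie the pipeline together, here is a shape-only trace of SWIN-T in NumPy. It is a sketch, not a model: the attention blocks are replaced by a hypothetical identity stand-in (`blocks_stub`), since windowed attention preserves shape, and all weights are random. Patch embedding is written as a reshape into 4×4-pixel patches followed by a linear projection, which is equivalent to a convolution with kernel size and stride 4.

```python
import numpy as np

rng = np.random.default_rng(0)

def patch_embed(img, P, d):
    """Split an (H, W, 3) image into P x P patches and project each to d dims."""
    H, Wd, C = img.shape
    patches = img.reshape(H // P, P, Wd // P, P, C).transpose(0, 2, 1, 3, 4)
    patches = patches.reshape(H // P, Wd // P, P * P * C)
    return patches @ (rng.standard_normal((P * P * C, d)) * 0.02)

def patch_merge(x):
    """2x2 merge: (H, W, d) -> (H/2, W/2, 2d) with random projection weights."""
    H, Wd, d = x.shape
    g = x.reshape(H // 2, 2, Wd // 2, 2, d).transpose(0, 2, 1, 3, 4)
    return g.reshape(H // 2, Wd // 2, 4 * d) @ (rng.standard_normal((4 * d, 2 * d)) * 0.02)

def blocks_stub(x, n_pairs):
    """Hypothetical stand-in for n_pairs W-MSA/SW-MSA pairs (shape-preserving)."""
    return x

x = patch_embed(rng.standard_normal((224, 224, 3)), P=4, d=96)
features = []
for stage, n_pairs in enumerate((1, 1, 3, 1)):
    if stage > 0:
        x = patch_merge(x)          # patch merging precedes Stages 2-4
    x = blocks_stub(x, n_pairs)
    features.append(x)              # one level of the 4-scale pyramid

print([f.shape for f in features])
# [(56, 56, 96), (28, 28, 192), (14, 14, 384), (7, 7, 768)]
```

`features` is exactly the multi-scale pyramid a dense-prediction head consumes; the classification path uses only `features[-1]`.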
SWIN variants. The four standard SWIN variants share the same four-stage structure and differ only in $d$ (initial channel width) and the number of block pairs per stage; the number of attention heads follows from $d$, since every variant uses 32 channels per head. Stage 3 varies most: SWIN-T uses three pairs while SWIN-S/B/L use nine. All other stages use one pair in every variant.
| Variant | $d$ (Stage 1) | Block pairs (S1, S2, S3, S4) | Params |
|---|---|---|---|
| SWIN-T (Tiny) | 96 | 1, 1, 3, 1 | ~28M |
| SWIN-S (Small) | 96 | 1, 1, 9, 1 | ~50M |
| SWIN-B (Base) | 128 | 1, 1, 9, 1 | ~88M |
| SWIN-L (Large) | 192 | 1, 1, 9, 1 | ~197M |
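The variant table can be expanded into per-stage configurations with a few lines of Python. `swin_config` is a hypothetical helper; its head counts assume the 32 channels per head used by the published SWIN configurations. Depths are reported in the paper's single-layer blocks, two per block pair.

```python
def swin_config(d: int, pairs: tuple, head_dim: int = 32) -> dict:
    """Per-stage widths, head counts, and depths for a four-stage SWIN variant."""
    widths = [d * 2 ** s for s in range(4)]               # doubled by each patch merge
    return {
        "widths": widths,
        "heads": [w // head_dim for w in widths],         # assumes 32 channels per head
        "blocks": [2 * p for p in pairs],                 # paper's single-layer blocks
        "total_blocks": 2 * sum(pairs),
    }

variants = {
    "SWIN-T": (96, (1, 1, 3, 1)),
    "SWIN-S": (96, (1, 1, 9, 1)),
    "SWIN-B": (128, (1, 1, 9, 1)),
    "SWIN-L": (192, (1, 1, 9, 1)),
}
for name, (d, pairs) in variants.items():
    print(name, swin_config(d, pairs))
```

For SWIN-T this yields block depths [2, 2, 6, 2] and heads [3, 6, 12, 24], matching the published configuration; SWIN-S/B/L all have 24 blocks in total.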