NETWORK ARCHITECTURE: MULTILAYER

The Viridithas 14.0.0 release is a significant improvement over the previous version, and has propelled Viridithas to global #13 on Stefan Pohl's Computer Chess rating list. As part of the improvements over version 13.0.0, Viridithas 14.0.0 comes with a novel neural network architecture. All versions of Viridithas since 3.0.0 have used neural networks, and all of these networks have had only a single hidden layer: the network linearly transforms the features of the board into a layer of neurons, activates those neurons, and then calculates the value of the position directly from these activations. This is a very simple architecture, but it is very effective.

With Viridithas 14.0.0, I've moved beyond this. If the Viridithas 13.0.0 network looked like this:

[A diagram of an NNUE architecture]

Then the Viridithas 14.0.0 network looks like this:

[A diagram of an improved NNUE architecture]

The network now has several hidden layers, and employs a couple of tricks to make this more efficient.

Pairwise multiplication

A basic fact of fully-connected neural network layers is that the vector-matrix multiplication required to calculate the activations of one layer from the previous layer has a computational cost proportional to the product of the number of neurons in the two layers. This means that replacing the single output neuron that follows the feature transformer with a 16-neuron hidden layer increases the computational cost of that part of the network by a factor of 16. This is too high a cost - gains in evaluation quality will be offset by a loss in inference speed. It doesn't matter if you're a bit better at evaluating positions when the other guy can evaluate 16 times more positions than you in the same period of time.
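
To make that concrete: a fully-connected layer mapping n inputs to m outputs costs n × m multiply-accumulates, so with the 4096-wide layer quoted later in this post, feeding into a single output neuron costs roughly 4,096 multiply-accumulates per evaluation, whereas feeding into a 16-neuron hidden layer costs roughly 4,096 × 16 = 65,536.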

In order to resolve this, the new architecture uses a dimensionality reduction technique called pairwise multiplication. The feature transformer generates its output vector as usual, then we activate the vector with clipped ReLU, and then we take pairs of values in the vector and multiply them together. The simplest way to do this would be to take adjacent pairs of values, but writing efficient code becomes easier if we instead take pairs of values separated by a distance of half the width of the vector - e.g. if the vector width is 8, we multiply indices (0, 4), (1, 5), (2, 6), and (3, 7) together. If you recall the results from the first ANNUEP post, you'll be aware that squared clipped ReLU is superior to clipped ReLU for chess networks, and may be questioning the choice to use CReLU for the pre-pairwise activations. Fear not - we actually retain much of the old benefit of SCReLU. Consider that SCReLU produces activations via

$v_i = \operatorname{clamp}(x_i, 0, 1)^2$

whereas CReLU-into-pairwise gives us

$v_i = \operatorname{clamp}(x_i, 0, 1) \cdot \operatorname{clamp}(x_{i + N/2}, 0, 1)$, where $N$ is the width of the vector.

This is almost the same as SCReLU, except that now we are taking the product of two different features!

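In code, a scalar version of this transformation looks like the following:
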
for (a, acc) in [us, them].into_iter().enumerate() {
    for i in 0..L1_SIZE / 2 {
      unsafe {
        let l = *acc.get_unchecked(i);
        let r = *acc.get_unchecked(L1_SIZE / 2 + i);
        let cl = i16::clamp(l, 0, QA);
        let cr = i16::clamp(r, 0, QA);
        *output.get_unchecked_mut(i + a * L1_SIZE / 2) =
            ((i32::from(cl) * i32::from(cr)) >> FT_SHIFT) as u8;
      }
    }
}
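
The vectorised implementation used in the engine is rather more involved - in addition to performing the CReLU-into-pairwise transformation, it records which chunks of the output are non-zero, for use by the sparse matrix multiplication described in the next section:
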
// Implementation of CReLU-into-pairwise, from Viridithas 19.0.0-dev.
let ft_zero = simd::zero_i16();
let ft_one = simd::splat_i16(QA);

let mut ft_outputs: Align64<[MaybeUninit<u8>; L1_SIZE]> =
    MaybeUninit::uninit().assume_init();
let mut nnz: Align64<[MaybeUninit<u16>; L1_SIZE / L1_CHUNK_PER_32]> =
    MaybeUninit::uninit().assume_init();
let mut nnz_count = 0;

let mut base = simd::v128_zero();
let increment = simd::v128_splat(8);

let mut offset = 0;
for acc in [us, them] {
    let acc_ptr = acc.as_ptr();

    for i in (0..L1_PAIR_COUNT).step_by(I16_CHUNK * 2 * 2) {
        // load the left-hand pair inputs
        let input0a = simd::load_i16(acc_ptr.add(i + 0 * I16_CHUNK));
        let input0b = simd::load_i16(acc_ptr.add(i + 1 * I16_CHUNK));
        let input0c = simd::load_i16(acc_ptr.add(i + 2 * I16_CHUNK));
        let input0d = simd::load_i16(acc_ptr.add(i + 3 * I16_CHUNK));

        // load the right-hand pair inputs
        let j = i + L1_PAIR_COUNT;
        let input1a = simd::load_i16(acc_ptr.add(j + 0 * I16_CHUNK));
        let input1b = simd::load_i16(acc_ptr.add(j + 1 * I16_CHUNK));
        let input1c = simd::load_i16(acc_ptr.add(j + 2 * I16_CHUNK));
        let input1d = simd::load_i16(acc_ptr.add(j + 3 * I16_CHUNK));

        // crelu the left-hand inputs
        let clipped0a = simd::min_i16(simd::max_i16(input0a, ft_zero), ft_one);
        let clipped0b = simd::min_i16(simd::max_i16(input0b, ft_zero), ft_one);
        let clipped0c = simd::min_i16(simd::max_i16(input0c, ft_zero), ft_one);
        let clipped0d = simd::min_i16(simd::max_i16(input0d, ft_zero), ft_one);

        // clip the right-hand inputs from above
        let clipped1a = simd::min_i16(input1a, ft_one);
        let clipped1b = simd::min_i16(input1b, ft_one);
        let clipped1c = simd::min_i16(input1c, ft_one);
        let clipped1d = simd::min_i16(input1d, ft_one);

        // shift and mulhi s.t. the high bits we get are equal to crelu(x1) * crelu(x2)
        let producta = simd::shift_mul_high_i16::<SHIFT>(clipped0a, clipped1a);
        let productb = simd::shift_mul_high_i16::<SHIFT>(clipped0b, clipped1b);
        let productc = simd::shift_mul_high_i16::<SHIFT>(clipped0c, clipped1c);
        let productd = simd::shift_mul_high_i16::<SHIFT>(clipped0d, clipped1d);

        // pack the resulting values in to u8s
        let product_one = simd::pack_i16_to_u8(producta, productb);
        let product_two = simd::pack_i16_to_u8(productc, productd);

        // store to the ft output buffer
        let ft_o_ptr = ft_outputs.as_mut_ptr();
        simd::store_u8(ft_o_ptr.add(offset + i).cast(), product_one);
        simd::store_u8(ft_o_ptr.add(offset + i + U8_CHUNK).cast(), product_two);

        // determine which parts of the result are non-zero,
        // to allow l1 propagation to happen sparsely
        let mut nnz_mask = 0;
        nnz_mask |= u32::from(simd::nonzero_mask_i32(simd::trans_i8_i32(product_one)));
        nnz_mask |= u32::from(simd::nonzero_mask_i32(simd::trans_i8_i32(product_two)))
            << NNZ_INPUT_SIMD_WIDTH;

        // store the non-zero indices into the nnz buffer
        for j in 0..NNZ_OUTPUTS_PER_CHUNK {
            let lookup = (nnz_mask >> (j * 8)) & 0xFF;
            let entry = NNZ_TABLE.table.as_ptr().add(lookup as usize);
            let offsets = simd::v128_load(entry.cast());
            simd::v128_store(
                nnz.as_mut_ptr().add(nnz_count).cast(),
                simd::v128_add(base, offsets),
            );
            nnz_count += u32::count_ones(lookup) as usize;
            base = simd::v128_add(base, increment);
        }
    }
    offset += L1_PAIR_COUNT;
}
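
The NNZ_TABLE used above maps each possible 8-bit mask to the indices of its set bits, padded to eight 16-bit entries so that each row can be loaded as a single 128-bit vector. A minimal sketch of how such a table could be built - the exact types and layout in Viridithas may differ - looks like this:

// Hypothetical construction of a mask -> set-bit-indices lookup table.
// Row `mask` holds the indices of the set bits of `mask`, padded with
// zeros to eight u16s so it can be fetched with a single 128-bit load.
pub struct NnzTable {
    pub table: [[u16; 8]; 256],
}

pub const fn build_nnz_table() -> NnzTable {
    let mut table = [[0u16; 8]; 256];
    let mut mask = 0usize;
    while mask < 256 {
        let mut count = 0;
        let mut bit: usize = 0;
        while bit < 8 {
            if mask & (1usize << bit) != 0 {
                table[mask][count] = bit as u16;
                count += 1;
            }
            bit += 1;
        }
        mask += 1;
    }
    NnzTable { table }
}

pub static NNZ_TABLE: NnzTable = build_nnz_table();

Because each row is exactly one 128-bit load, the inference code can add the running base offsets to a loaded row with a single v128_add, and then advance nnz_count by count_ones(lookup) so the zero padding is simply skipped over.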

Sparse matrix multiplication

To accelerate inference, we use sparse matrix multiplication. For this to work effectively, we need many of the activations in the feature vector to be 0. During training, we apply an L1 loss to the activations of the feature transformer output. Typically, regularisation losses like L1 and L2 are applied to the parameters of the network, but the goal here is to encourage gradient descent to minimise the number of non-zero activations in the feature transformer output¹. This isn't the only reason to apply such a loss to the activations (activation L1 loss has gained Elo for engines that do not use sparse matmul), but it's a good one.
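
In the trainer, this amounts to adding a term proportional to the mean absolute feature-transformer activation to the loss. A minimal sketch of the idea - illustrative names only, and the real training code operates on batched tensors rather than plain slices:

// Illustrative only: an L1 penalty on feature-transformer activations,
// added on top of the usual evaluation loss. `lambda` is a small tunable
// coefficient; larger values push more activations towards zero.
fn total_loss(eval_loss: f32, ft_activations: &[f32], lambda: f32) -> f32 {
    let mean_abs_activation = ft_activations.iter().map(|a| a.abs()).sum::<f32>()
        / ft_activations.len() as f32;
    eval_loss + lambda * mean_abs_activation
}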

An issue for sparse matrix multiplication is that the performance of network inference code relies on being able to continually fill whole SIMD registers with data. The activations of the feature transformer are single bytes, so tracking sparsity at the granularity of individual activations would cost quite a lot of memory in bookkeeping. Instead, we treat each group of four activations as a 32-bit integer, and compute the nonzero mask of a SIMD register full of these 32-bit integers.

// Implementations of nonzero_mask_i32. We return an
// unsigned 16-bit integer to have space for AVX512
// registers, where we can fit 16 32-bit integers in a
// single register. AVX2 and SSSE3 only use eight and four
// bits of this returned mask, respectively.

// x86-64-v2:
pub unsafe fn nonzero_mask_i32(vec: VecI32) -> u16 {
    return _mm_movemask_ps(_mm_castsi128_ps(_mm_cmpgt_epi32(vec.inner(), _mm_setzero_si128()))) as u16;
}
// x86-64-v3:
pub unsafe fn nonzero_mask_i32(vec: VecI32) -> u16 {
    return _mm256_movemask_ps(_mm256_castsi256_ps(_mm256_cmpgt_epi32(vec.inner(), _mm256_setzero_si256()))) as u16;
}
// x86-64-v4:
pub unsafe fn nonzero_mask_i32(vec: VecI32) -> u16 {
    return _mm512_cmpgt_epi32_mask(vec.inner(), _mm512_setzero_si512()) as u16;
}
// aarch64 NEON:
pub unsafe fn nonzero_mask_i32(vec: VecI32) -> u16 {
    static MASK: [u32; 4] = [1, 2, 4, 8];
    let a = std::mem::transmute(vec.inner());
    vaddvq_u32(vandq_u32(vtstq_u32(a, a), vld1q_u32(MASK.as_ptr()))) as u16
}

With these masks in hand, we process them sequentially to compute the product of the feature transformer output with the weights matrix.

// Implementation of sparse matmul given non-zero indices
// using an interface that abstracts over architectures.

// &Align64<[MaybeUninit<u8>; L1_SIZE]> -> &Align64<[i32; L1_SIZE / 4]>
let input32 = reinterpret_as_i32s(ft_outputs);
let mut sums = Align64([0; L2_SIZE]);
let nnz_count = nnz_slice.len();

let tail_start = nnz_count - (nnz_count % 4);

// affine transform
for i in (0..tail_start).step_by(4) {
    // load the block indices from the sparse index list
    let nnz_ia = *nnz_slice.get_unchecked(i + 0) as usize;
    let nnz_ib = *nnz_slice.get_unchecked(i + 1) as usize;
    let nnz_ic = *nnz_slice.get_unchecked(i + 2) as usize;
    let nnz_id = *nnz_slice.get_unchecked(i + 3) as usize;
    // load the non-zero blocks, and splat them into SIMD registers.
    let input32_a = simd::trans_i32_i8(simd::splat_i32(*input32.get_unchecked(nnz_ia)));
    let input32_b = simd::trans_i32_i8(simd::splat_i32(*input32.get_unchecked(nnz_ib)));
    let input32_c = simd::trans_i32_i8(simd::splat_i32(*input32.get_unchecked(nnz_ic)));
    let input32_d = simd::trans_i32_i8(simd::splat_i32(*input32.get_unchecked(nnz_id)));
    // compute the block indices into the weights matrix.
    let w_offset_a = nnz_ia * L2_SIZE * L1_CHUNK_PER_32;
    let w_offset_b = nnz_ib * L2_SIZE * L1_CHUNK_PER_32;
    let w_offset_c = nnz_ic * L2_SIZE * L1_CHUNK_PER_32;
    let w_offset_d = nnz_id * L2_SIZE * L1_CHUNK_PER_32;
    // for each SIMD-block in the row, compute the product
    // of the non-zero activation with the corresponding
    // weight, and add it to the accumulator.
    for k in 0..L2_SIZE / F32_CHUNK {
        let sum = simd::load_i32(sums.as_ptr().add(k * F32_CHUNK));
        let weight_a = simd::load_i8(weights.as_ptr().add(w_offset_a + k * U8_CHUNK));
        let weight_b = simd::load_i8(weights.as_ptr().add(w_offset_b + k * U8_CHUNK));
        let weight_c = simd::load_i8(weights.as_ptr().add(w_offset_c + k * U8_CHUNK));
        let weight_d = simd::load_i8(weights.as_ptr().add(w_offset_d + k * U8_CHUNK));
        let res = simd::madd_2xu8_to_i32(sum, input32_a, weight_a, input32_b, weight_b);
        let res = simd::madd_2xu8_to_i32(res, input32_c, weight_c, input32_d, weight_d);
        simd::store_i32(sums.as_mut_ptr().add(k * F32_CHUNK), res);
    }
}

// process the tail
for i in tail_start..nnz_count {
    // load the block index from the sparse index list
    let nnz_i = *nnz_slice.get_unchecked(i) as usize;
    // load the non-zero block, and splat it into a SIMD register.
    let input32 = simd::trans_i32_i8(simd::splat_i32(*input32.get_unchecked(nnz_i)));
    // compute the block index into the weights matrix.
    let w_offset = nnz_i * L2_SIZE * L1_CHUNK_PER_32;
    // for each SIMD-block in the row, compute the product
    // of the non-zero activation with the corresponding
    // weight, and add it to the accumulator.
    for k in 0..L2_SIZE / F32_CHUNK {
        let sum    = simd::load_i32(sums.as_ptr().add(k * F32_CHUNK));
        let weight = simd::load_i8(weights.as_ptr().add(w_offset + k * U8_CHUNK));
        let res    = simd::madd_u8_to_i32(sum, input32, weight);
        simd::store_i32(sums.as_mut_ptr().add(k * F32_CHUNK), res);
    }
}

// squared clipped ReLU activation
let zero    = simd::zero_f32();
let one     = simd::splat_f32(1.0);
let sum_mul = simd::splat_f32(L1_MUL);
for i in 0..L2_SIZE / F32_CHUNK {
    // convert i32 to f32, multiplying by the quantisation constant
    let bias = simd::load_f32(biases.as_ptr().add(i * F32_CHUNK));
    let unscaled = simd::i32_to_f32(simd::load_i32(sums.as_ptr().add(i * F32_CHUNK)));
    let preact   = simd::madd_f32(unscaled, sum_mul, bias);
    // activate
    let clipped = simd::min_f32(simd::max_f32(preact, zero), one);
    let squared = simd::mul_f32(clipped, clipped);
    simd::store_f32(output.as_mut_ptr().add(i * F32_CHUNK), squared);
}
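
The same routine again, in a more condensed form with the relevant variables annotated:
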
// Implementation of sparse matmul given non-zero indices
// using an interface that abstracts over x86-64-v{2,3,4}.

// input32: Align64<[i32; L1_SIZE / L1_CHUNK_PER_32]>, the chunked activations.
// nnz: Align64<[u16; L1_SIZE / 4]>, the non-zero indices buffer.
// nnz_count is the number of non-zero indices in the chunked activations.
// L1_CHUNK_PER_32 is 4, as four bytes fit in one 32-bit integer.
// sums are the accumulators for the sparse matmul.
for &i in &nnz[..nnz_count] {
    // load the non-zero activation, and splat it into a SIMD register.
    let input = simd::splat_i32(input32[i as usize]);
    // compute the index into the weights matrix.
    let i_col = i as usize * L2_SIZE * L1_CHUNK_PER_32;
    // index the row of the weights matrix, and reinterpret
    // it as an array of SIMD blocks.
    let col = std::ptr::from_ref(&weights[i_col]).cast::<VecI8>();
    // for each SIMD-block in the row, compute the product
    // of the non-zero activation with the corresponding
    // weight, and add it to the accumulator.
    for k in 0..L2_SIZE / F32_CHUNK_SIZE {
        sums[k] = simd::mul_add_u8_to_i32(
            sums[k],
            simd::reinterpret_i32s_as_i8s(input),
            *col.add(k),
        );
    }
}

// Add biases, convert to floats, and run L1 activation.
let zero = simd::zero_f32();
let one = simd::splat_f32(1.0);
// L1_MUL is a factor to remove quantisation constants.
let sum_mul = simd::splat_f32(L1_MUL);
for i in 0..L2_SIZE / F32_CHUNK_SIZE {
    // Convert into floats, and activate L1
    let bias    = simd::load_f32(&biases[i * F32_CHUNK_SIZE]);
    let sum     = simd::mul_add_f32(simd::i32_to_f32(sums[i]), sum_mul, bias);
    let clipped = simd::min_f32(simd::max_f32(sum, zero), one);
    let squared = simd::mul_f32(clipped, clipped);
    simd::store_f32(&mut output[i * F32_CHUNK_SIZE], squared);
}

Full-precision multiplication

As later layers of the network are much smaller (16 to 32 neurons, instead of 4096), the cost of these later layers is much lower. As such, we can use full-precision floating-point multiplication without much computational cost, avoiding the inaccuracy of fixed-point integer quantisation.

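A straightforward scalar implementation of one of these later layers (propagating L2 into L3) looks like this:
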
pub fn propagate_l2(
    inputs: &Align64<[f32; L2_SIZE]>,
    weights: &Align64<[f32; L2_SIZE * L3_SIZE]>,
    biases: &Align64<[f32; L3_SIZE]>,
    output: &mut Align64<[f32; L3_SIZE]>,
) {
  unsafe {
    let mut sums = biases.clone();

    // affine transform
    for i in 0..L2_SIZE {
        let input = *inputs.get_unchecked(i);
        for j in 0..L3_SIZE {
            let sum = *sums.get_unchecked(j);
            let w = *weights.get_unchecked(i * L3_SIZE + j);
            *sums.get_unchecked_mut(j) = input.mul_add(w, sum);
        }
    }

    // squared clipped ReLU activation
    for i in 0..L3_SIZE {
        let clipped = f32::clamp(*sums.get_unchecked(i), 0.0, 1.0);
        *output.get_unchecked_mut(i) = clipped * clipped;
    }
  }
}
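
And the vectorised version of the same function:
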
pub fn propagate_l2(
    inputs: &Align64<[f32; L2_SIZE]>,
    weights: &Align64<[f32; L2_SIZE * L3_SIZE]>,
    biases: &Align64<[f32; L3_SIZE]>,
    output: &mut Align64<[f32; L3_SIZE]>,
) {
  unsafe {
    let mut sums = biases.clone();

    // affine transform
    for i in 0..L2_SIZE {
        let activation = simd::splat_f32(*inputs.get_unchecked(i));
        for j in 0..L3_SIZE / F32_CHUNK {
            let acc = simd::load_f32(sums.as_ptr().add(j * F32_CHUNK));
            let weight = simd::load_f32(weights.as_ptr().add(i * L3_SIZE + j * F32_CHUNK));
            let res = simd::madd_f32(activation, weight, acc);
            simd::store_f32(sums.as_mut_ptr().add(j * F32_CHUNK), res);
        }
    }

    // squared clipped ReLU activation
    let zero = simd::zero_f32();
    let one = simd::splat_f32(1.0);
    for i in 0..L3_SIZE / F32_CHUNK {
        let acc = simd::load_f32(sums.as_ptr().add(i * F32_CHUNK));
        let clipped = simd::min_f32(simd::max_f32(acc, zero), one);
        let squared = simd::mul_f32(clipped, clipped);
        simd::store_f32(output.as_mut_ptr().add(i * F32_CHUNK), squared);
    }
  }
}

Thanks for reading! I hope you enjoyed this post. If you have any questions, comments, or suggestions about Viridithas, please feel free to open an issue on the GitHub repository! The next post will be about my experiments with new training targets for neural networks, and how I'm using them to improve Viridithas.


  1. As the goal is sparsity, L1 isn't quite the right regularisation loss - if it were differentiable, L0 loss would be correct. Nevertheless, it works, and L1 has additional benefits on network strength irrespective of the effects it has on non-zero activations. If you want to read about work being done on encouraging sparsity for interpreting neural language models, I recommend the JumpReLU paper from Google DeepMind.