NETWORK ARCHITECTURE: MULTILAYER

The Viridithas 14.0.0 release is a significant improvement over the previous version, and has propelled Viridithas to #13 in the world on Stefan Pohl's Computer Chess rating list. As part of the improvements over version 13.0.0, Viridithas 14.0.0 comes with a novel neural network architecture. All versions of Viridithas since 3.0.0 have used neural networks, and all of these networks have had only a single hidden layer: the network linearly transforms the features of the board into a layer of neurons, activates those neurons, and then calculates the value of the position directly from these activations. This is a very simple architecture, but a very effective one.
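To make the single-hidden-layer design concrete, here is a minimal floating-point sketch of such an evaluation. The function name, the dense `Vec` weight layout, and the use of clipped ReLU as the activation are illustrative assumptions; the real network is quantised, uses SIMD, and evaluates the position from both sides' perspectives.

```rust
/// Sketch of a single-hidden-layer evaluation (hypothetical, f32 for clarity).
fn single_layer_eval(
    features: &[usize],      // indices of the active input features
    ft_weights: &[Vec<f32>], // one hidden-layer weight vector per input feature
    ft_bias: &[f32],         // hidden-layer biases
    out_weights: &[f32],     // hidden -> output weights
    out_bias: f32,
) -> f32 {
    // Linear transform: start from the biases and add the weight
    // vector of every active feature.
    let mut hidden = ft_bias.to_vec();
    for &f in features {
        for (h, &w) in hidden.iter_mut().zip(&ft_weights[f]) {
            *h += w;
        }
    }
    // Activate with clipped ReLU, then a single dot product gives the value.
    hidden
        .iter()
        .zip(out_weights)
        .map(|(&h, &w)| h.clamp(0.0, 1.0) * w)
        .sum::<f32>()
        + out_bias
}
```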

With Viridithas 14.0.0, I've moved beyond this. If the Viridithas 13.0.0 network looked like this:

a diagram of an NNUE architecture

Then the Viridithas 14.0.0 network looks like this:

a diagram of an improved NNUE architecture

The network now has several hidden layers, and employs a couple of tricks to keep inference efficient.

Pairwise multiplication

A basic fact about fully-connected neural network layers is that the vector-matrix multiplication required to calculate the activations of the neurons in layer $\ell + 1$ from those in layer $\ell$ has a computational cost proportional to $n_\ell \cdot n_{\ell+1}$, the product of the number of neurons in the two layers. This means that moving from a network whose first hidden layer feeds straight into a single output neuron to one whose first hidden layer feeds into a second hidden layer of $k$ neurons increases the cost of that matrix multiplication by a factor of $k$. With $k = 16$, this is too high a cost: gains in evaluation quality will be offset by a loss in inference speed. It doesn't matter if you're a bit better at evaluating positions when the other guy can evaluate 16 times as many positions as you in the same period of time.
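The arithmetic is easy to check. The sketch below counts multiply-accumulate operations per dense layer; the 2048-neuron width used in the usage note is an arbitrary stand-in chosen for illustration, not necessarily Viridithas's actual layer size.

```rust
/// Multiply-accumulate operations for one dense layer:
/// every input neuron contributes to every output neuron.
fn dense_layer_macs(inputs: usize, outputs: usize) -> usize {
    inputs * outputs
}

/// Total multiply-accumulates for a stack of dense layers given their widths,
/// e.g. `[2048, 16, 1]` means a 2048 -> 16 layer followed by a 16 -> 1 layer.
fn network_macs(widths: &[usize]) -> usize {
    widths
        .windows(2)
        .map(|w| dense_layer_macs(w[0], w[1]))
        .sum()
}
```

With those numbers, `2048 -> 1` costs 2048 multiply-accumulates, while `2048 -> 16 -> 1` costs 32784, almost all of it in the first matrix multiplication.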

In order to resolve this, the new architecture uses a dimensionality reduction technique called pairwise multiplication. The feature transformer generates the $N$-element vector as usual, then we activate the vector with clipped ReLU, and then we take pairs of values in the vector and multiply them together, producing an $N/2$-element vector. The simplest way to do this would be to take adjacent pairs of values, but writing efficient code becomes easier if we instead take pairs of values separated by half the width of the vector: e.g. if the vector width is 8, we multiply together indices 0 and 4, 1 and 5, 2 and 6, and 3 and 7. If you recall the results from the first ANNUEP post, you'll know that squared clipped ReLU (SCReLU) is superior to clipped ReLU (CReLU) for chess networks, and you may be questioning the choice to use CReLU for the pre-pairwise activations. Fear not: we actually retain much of the benefit of SCReLU. Consider that SCReLU produces activations via

$$a_i = \mathrm{clamp}(x_i, 0, 1)^2,$$

whereas CReLU-into-pairwise gives us

$$a_i = \mathrm{clamp}(x_i, 0, 1) \cdot \mathrm{clamp}(x_{i + N/2}, 0, 1), \qquad 0 \le i < N/2.$$

This is almost the same as SCReLU, except that now we are taking the product of two different features!
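A scalar floating-point reference for the two activations above may help. It is a sketch with hypothetical function names, ignoring quantisation and SIMD entirely; it pairs index `i` with index `i + N/2`, as described.

```rust
/// Clipped ReLU.
fn crelu(x: f32) -> f32 {
    x.clamp(0.0, 1.0)
}

/// Squared clipped ReLU, for comparison.
fn screlu(x: f32) -> f32 {
    let c = crelu(x);
    c * c
}

/// CReLU-into-pairwise on an N-element accumulator (N even):
/// output[i] = crelu(acc[i]) * crelu(acc[i + N/2]), halving the width.
fn pairwise_crelu(acc: &[f32]) -> Vec<f32> {
    assert!(acc.len() % 2 == 0, "accumulator width must be even");
    let half = acc.len() / 2;
    (0..half).map(|i| crelu(acc[i]) * crelu(acc[i + half])).collect()
}
```

When both members of a pair hold the same value, the pairwise product reduces exactly to SCReLU of that value.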

// Implementation of CReLU-into-pairwise, from Viridithas 15.0.0-dev.

// The SIMD uses a trick to save a lot of work -
// Here's an explanation written by https://github.com/cj5716, a proper genius:
// "What we want to do is multiply inputs in a pairwise manner (after clipping),
// and then shift right by FT_SHIFT. Instead, we shift left by (16 - FT_SHIFT),
// and use mulhi, stripping the bottom 16 bits, effectively shifting right by 16,
// resulting in a net shift of FT_SHIFT bits. We use mulhi because it maintains
// the sign of the multiplication (unlike mullo), allowing us to make use of
// packus to clip 2 of the inputs, resulting in a saving of two calls to
// `vec_max_epi16`."

// vector of zeros.
let ft_zero = simd::zero_i16();
// vector of `QA`, the quantisation constant that represents a floating-point value of 1.0.
let ft_one = simd::splat_i16(QA as i16);

// The buffer that will hold the output of the pairwise multiplication.
let mut ft_outputs: Align64<[MaybeUninit<u8>; L1_SIZE]> = MaybeUninit::uninit().assume_init();

// Offset into the buffer, as we process the input in halves that are
// processed in different orders depending on the side to move.
let mut offset = 0;
for acc in [us, them] {
    for i in (0..L1_PAIR_COUNT).step_by(I16_CHUNK_SIZE * 2) {
        // load the input activations:
        let input0a = simd::load_i16(acc.get_unchecked(i + 0 + 0));
        let input0b = simd::load_i16(acc.get_unchecked(i + I16_CHUNK_SIZE + 0));
        let input1a = simd::load_i16(acc.get_unchecked(i + 0 + L1_PAIR_COUNT));
        let input1b = simd::load_i16(acc.get_unchecked(i + I16_CHUNK_SIZE + L1_PAIR_COUNT));

        // clip the inputs to [0.0, 1.0]:
        let clipped0a = simd::min_i16(simd::max_i16(input0a, ft_zero), ft_one);
        let clipped0b = simd::min_i16(simd::max_i16(input0b, ft_zero), ft_one);
        let clipped1a = simd::min_i16(input1a, ft_one);
        let clipped1b = simd::min_i16(input1b, ft_one);

        // multiply the clipped inputs, and store the result in the output buffer.
        let producta = simd::mul_high_i16(simd::shl_i16::<{ 16 - FT_SHIFT as S }>(clipped0a), clipped1a);
        let productb = simd::mul_high_i16(simd::shl_i16::<{ 16 - FT_SHIFT as S }>(clipped0b), clipped1b);
        simd::store_u8(
            std::ptr::from_mut(ft_outputs.get_unchecked_mut(offset + i)).cast(),
            simd::pack_i16_to_unsigned_and_permute(producta, productb),
        );
    }
    offset += L1_PAIR_COUNT;
}
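The mulhi trick can be checked exhaustively with plain integers. The sketch below models `mul_high_i16` and the unsigned pack in scalar Rust and compares the trick against the naive clip, multiply, and shift. The constants `QA = 255` and `FT_SHIFT = 9` are assumptions chosen for illustration (they keep `QA << (16 - FT_SHIFT)` within `i16` range and the result within `u8` range), and all function names are hypothetical.

```rust
// Assumed constants for illustration only.
const QA: i16 = 255;
const FT_SHIFT: u32 = 9;

/// Scalar model of a 16-bit mulhi: the high 16 bits of the 32-bit product.
fn mul_high_i16(a: i16, b: i16) -> i16 {
    ((a as i32 * b as i32) >> 16) as i16
}

/// Scalar model of an unsigned saturating pack: clamp an i16 into a u8.
fn pack_to_u8(x: i16) -> u8 {
    x.clamp(0, 255) as u8
}

/// The trick: clip `a` to [0, QA], clip `b` only from above, pre-shift `a`
/// left by 16 - FT_SHIFT, take mulhi, and let the pack zero any negative product.
fn pairwise_trick(a: i16, b: i16) -> u8 {
    let a = a.clamp(0, QA);
    let b = b.min(QA);
    pack_to_u8(mul_high_i16(a << (16 - FT_SHIFT), b))
}

/// The straightforward version: clip both inputs, multiply, shift right.
fn pairwise_naive(a: i16, b: i16) -> u8 {
    let a = a.clamp(0, QA) as i32;
    let b = b.clamp(0, QA) as i32;
    ((a * b) >> FT_SHIFT) as u8
}
```

Because shifting left by 16 - FT_SHIFT and then discarding the low 16 bits is the same as a right shift by FT_SHIFT, the two agree for every input, while the trick needs one fewer clip per pair.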

Sparse matrix multiplication

In order to accelerate inference, we use sparse matrix multiplication. For this to work effectively, we need many of the activations in the feature vector to be 0. During training, we apply an L1 loss to the activations of the feature transformer output. Regularisation losses like L1 and L2 are typically applied to the parameters of the network, but here the goal is to encourage gradient descent to minimise the number of non-zero activations in the feature transformer output[^1]. This isn't the only reason to apply such a loss on the activations (activation L1 loss has gained Elo for engines that do not use sparse matmul), but it's a good one.
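As a rough illustration of what the activation L1 term computes, here is a sketch over a single vector of feature-transformer activations. The function names and the coefficient `lambda` are hypothetical; in practice the term would be added to the prediction loss inside the training framework, which computes gradients automatically.

```rust
/// L1 penalty on activations: lambda * sum(|a|).
fn activation_l1_penalty(activations: &[f32], lambda: f32) -> f32 {
    lambda * activations.iter().map(|a| a.abs()).sum::<f32>()
}

/// (Sub)gradient of the penalty with respect to each activation:
/// lambda * sign(a), taking 0 at exactly zero.
fn activation_l1_grad(activations: &[f32], lambda: f32) -> Vec<f32> {
    activations
        .iter()
        .map(|&a| if a == 0.0 { 0.0 } else { lambda * a.signum() })
        .collect()
}
```

Since clipped-ReLU outputs are non-negative, the penalty is just `lambda` times the sum of the activations, and its gradient pushes every active neuron towards zero with the same strength regardless of its size, which tends to leave many of them at exactly zero.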

An issue for sparse matrix multiplication is that the performance of network inference code relies on being able to continually fill whole SIMD registers with data. The activations of the feature transformer are single bytes, so if we tracked sparsity at the granularity of individual activations, we would spend quite a lot of memory on this bookkeeping. Instead, we treat each group of 4 consecutive activations as a single 32-bit integer, and compute the nonzero mask of a SIMD register full of these 32-bit integers.

// Implementations of nonzero_mask_i32 for x86-64-v2, x86-64-v3, and x86-64-v4.
// We return an unsigned 16-bit integer to have space for AVX512 registers, where
// we can fit 16 32-bit integers in a single register. AVX2 and SSSE3 only use
// eight and four bits of this returned mask, respectively. This is potentially
// an area for optimisation.
// On x86-64-v2:
pub unsafe fn nonzero_mask_i32(vec: VecI32) -> u16 {
    return _mm_movemask_ps(_mm_castsi128_ps(_mm_cmpgt_epi32(vec.inner(), _mm_setzero_si128()))) as u16;
}
// On x86-64-v3:
pub unsafe fn nonzero_mask_i32(vec: VecI32) -> u16 {
    return _mm256_movemask_ps(_mm256_castsi256_ps(_mm256_cmpgt_epi32(vec.inner(), _mm256_setzero_si256()))) as u16;
}
// On x86-64-v4:
pub unsafe fn nonzero_mask_i32(vec: VecI32) -> u16 {
    return _mm512_cmpgt_epi32_mask(vec.inner(), _mm512_setzero_si512()) as u16;
}
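The same idea in scalar form: view the byte activations as 4-byte chunks, and record the index of every chunk that is not entirely zero. The sketch below also shows one common way to turn a bitmask into a list of indices, by repeatedly taking and clearing the lowest set bit. Both helpers are hypothetical names, not Viridithas's code. Note that the scalar version tests `!= 0`, while the SIMD versions above test signed `> 0`; these agree only as long as no chunk has its sign bit set, i.e. the most significant byte stays below 128.

```rust
/// Indices of the 4-byte chunks of `activations` that are not all zero.
fn nonzero_chunk_indices(activations: &[u8]) -> Vec<u16> {
    assert!(activations.len() % 4 == 0, "length must be a multiple of 4");
    activations
        .chunks_exact(4)
        .enumerate()
        .filter(|(_, c)| i32::from_ne_bytes([c[0], c[1], c[2], c[3]]) != 0)
        .map(|(i, _)| i as u16)
        .collect()
}

/// Append `base + b` to `out` for every set bit `b` in `mask`.
fn push_mask_indices(mut mask: u16, base: u16, out: &mut Vec<u16>) {
    while mask != 0 {
        out.push(base + mask.trailing_zeros() as u16);
        mask &= mask - 1; // clear the lowest set bit
    }
}
```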

With these masks in hand, we convert them into a list of the indices of non-zero chunks, and then iterate over that list to compute the product of the feature transformer output with the weight matrix, skipping every all-zero chunk.

// Implementation of sparse matmul given non-zero indices using an interface that abstracts over x86-64-v{2,3,4}.

// input32: Align64<[i32; L1_SIZE / L1_CHUNK_PER_32]>, the chunked activations.
// nnz: Align64<[u16; L1_SIZE / 4]>, the non-zero indices buffer.
// nnz_count is the number of non-zero indices in the chunked activations.
// L1_CHUNK_PER_32 is 4, as four bytes fit in one 32-bit integer.
// sums are the accumulators for the sparse matmul.
for &i in &nnz[..nnz_count] {
    // nnz stores u16 indices; widen to usize for indexing.
    let i = i as usize;
    // load the non-zero chunk of four activations, and splat it into a SIMD register.
    let input = simd::splat_i32(input32[i]);
    // compute the offset of this chunk's weights in the weights matrix.
    let i_col = i * L2_SIZE * L1_CHUNK_PER_32;
    // take the weights for this chunk, and reinterpret
    // them as an array of SIMD blocks.
    let col = std::ptr::from_ref(&weights[i_col]).cast::<VecI8>();
    // for each SIMD-block in the row, compute the product
    // of the non-zero activation with the corresponding
    // weight, and add it to the accumulator.
    for k in 0..L2_SIZE / F32_CHUNK_SIZE {
        sums[k] = simd::mul_add_u8_to_i32(
            sums[k],
            simd::reinterpret_i32s_as_i8s(input),
            *col.add(k),
        );
    }
}

// Add biases, convert to floats, and run L1 activation.
let zero = simd::zero_f32();
let one = simd::splat_f32(1.0);
// L1_MUL is a factor to remove quantisation constants.
let sum_mul = simd::splat_f32(L1_MUL);
for i in 0..L2_SIZE / F32_CHUNK_SIZE {
    // Convert into floats, and activate L1
    let bias    = simd::load_f32(&biases[i * F32_CHUNK_SIZE]);
    let sum     = simd::mul_add_f32(simd::i32_to_f32(sums[i]), sum_mul, bias);
    let clipped = simd::min_f32(simd::max_f32(sum, zero), one);
    let squared = simd::mul_f32(clipped, clipped);
    simd::store_f32(&mut output[i * F32_CHUNK_SIZE], squared);
}
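Stripped of SIMD, the loop above is an ordinary matrix-vector product that skips all-zero chunks. This scalar sketch uses hypothetical names and a simple row-major `i8` weight layout (`w[k * outputs + j]` connects input `k` to output `j`), which is simpler than the chunk-interleaved layout the SIMD code reads from. It checks that visiting only the listed chunks gives the same sums as the dense product.

```rust
/// Dense reference: sums[j] = sum over k of x[k] * w[k * outputs + j].
fn dense_u8_i8(x: &[u8], w: &[i8], outputs: usize) -> Vec<i32> {
    let mut sums = vec![0i32; outputs];
    for (k, &xk) in x.iter().enumerate() {
        for j in 0..outputs {
            sums[j] += xk as i32 * w[k * outputs + j] as i32;
        }
    }
    sums
}

/// Sparse version: only visit the 4-byte chunks whose indices are in `nnz`.
fn sparse_u8_i8(x: &[u8], nnz: &[u16], w: &[i8], outputs: usize) -> Vec<i32> {
    let mut sums = vec![0i32; outputs];
    for &chunk in nnz {
        let base = chunk as usize * 4; // four activations per chunk
        for k in base..base + 4 {
            let xk = x[k] as i32;
            for j in 0..outputs {
                sums[j] += xk * w[k * outputs + j] as i32;
            }
        }
    }
    sums
}
```

Every skipped chunk contributes exactly zero to each sum, so the work saved scales with the fraction of all-zero chunks, which is what the activation L1 loss is meant to increase.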

[^1]: As the goal is sparsity, L1 isn't quite the right regularisation loss - if it were differentiable, L0 loss would be correct. Nevertheless, it works, and L1 has additional benefits on network strength irrespective of the effects it has on non-zero activations. If you want to read about work being done on encouraging sparsity for interpreting neural language models, I recommend the JumpReLU paper from Google DeepMind.

Full-precision multiplication

As the later layers of the network are much smaller (16 to 32 neurons, instead of 4096), they are much cheaper to compute. As such, we can use full-precision floating-point arithmetic for them at little extra cost, avoiding the inaccuracy of fixed-point integer quantisation.

// Example of floating-point matrix multiply for layer 2:
pub fn propagate_l2(
    inputs:  &Align64<[f32; L2_SIZE]>,
    weights: &Align64<[f32; L2_SIZE * L3_SIZE]>,
    biases:  &Align64<[f32; L3_SIZE]>,
    output:  &mut Align64<[f32; L3_SIZE]>,
) {
    // SAFETY: we never hold multiple mutable references, we never mutate immutable memory, etc.
    unsafe {
        let mut sums = [0.0; L3_SIZE];

        // Load biases into sums.
        for i in 0..L3_SIZE / F32_CHUNK_SIZE {
            simd::store_f32(
                sums.get_unchecked_mut(i * F32_CHUNK_SIZE),
                simd::load_f32(biases.get_unchecked(i * F32_CHUNK_SIZE)),
            );
        }

        // Multiply inputs by weights, and add to sums.
        for i in 0..L2_SIZE {
            let input_vec = simd::splat_f32(*inputs.get_unchecked(i));
            for j in 0..L3_SIZE / F32_CHUNK_SIZE {
                simd::store_f32(
                    sums.get_unchecked_mut(j * F32_CHUNK_SIZE),
                    simd::mul_add_f32(
                        input_vec,
                        simd::load_f32(weights.get_unchecked(i * L3_SIZE + j * F32_CHUNK_SIZE)),
                        simd::load_f32(sums.get_unchecked(j * F32_CHUNK_SIZE)),
                    ),
                );
            }
        }

        // Activate L2 with squared clipped ReLU.
        let one = simd::splat_f32(1.0);
        for i in 0..L3_SIZE / F32_CHUNK_SIZE {
            let clipped = simd::min_f32(
                simd::max_f32(simd::load_f32(sums.get_unchecked(i * F32_CHUNK_SIZE)), simd::zero_f32()),
                one,
            );
            let squared = simd::mul_f32(clipped, clipped);
            simd::store_f32(output.get_unchecked_mut(i * F32_CHUNK_SIZE), squared);
        }
    }
}
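For readers who find the intrinsics hard to follow, here is the same computation in plain scalar Rust. It is a sketch with a hypothetical name, assuming the row-major layout implied by the indexing in `propagate_l2` (`weights[i * L3_SIZE + j]` connects input `i` to output `j`). If `simd::mul_add_f32` uses fused multiply-add instructions, the SIMD version may differ from this one in the last bits.

```rust
/// Dense f32 layer followed by squared clipped ReLU.
/// `weights[i * out_len + j]` connects input `i` to output `j`.
fn dense_screlu_f32(inputs: &[f32], weights: &[f32], biases: &[f32]) -> Vec<f32> {
    let out_len = biases.len();
    assert_eq!(weights.len(), inputs.len() * out_len);
    // Start from the biases, as propagate_l2 does.
    let mut sums = biases.to_vec();
    // Accumulate each input's weight row, in the same order as the SIMD loop.
    for (i, &x) in inputs.iter().enumerate() {
        let row = &weights[i * out_len..(i + 1) * out_len];
        for (s, &w) in sums.iter_mut().zip(row) {
            *s += x * w;
        }
    }
    // Squared clipped ReLU.
    sums.into_iter()
        .map(|s| {
            let c = s.clamp(0.0, 1.0);
            c * c
        })
        .collect()
}
```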

Thanks for reading! I hope you enjoyed this post. If you have any questions, comments, or suggestions about Viridithas, please feel free to open an issue on the GitHub repository! The next post will be about my experiments with new training targets for neural networks, and how I'm using them to improve Viridithas.