Concurrent Networking Code in Rust

Concurrency matters when writing applications that do networking tasks, like requesting resources over the internet. Operations like this can take a relatively long time due to the limitations of the network, but are often not computationally expensive. We can run a whole lot of them nearly simultaneously, because most of the "work" is the application waiting on the network.

These tasks aren't precisely simultaneous: Concurrency is not Parallelism. That's an important distinction for CPU-bound tasks, where we really need to squeeze every ounce of productivity out of the CPU cores at hand. But in the world of network programming, we can often run a large number of network-constrained actions concurrently and have it look pretty much like parallelism, since so much of the time is spent waiting for data to come in over the wire.

Concurrency like this doesn't come standard with Rust: the language gives you async/await syntax, but the asynchronous runtime has to come from a community-maintained crate. There are multiple options here, but these examples will take an opinionated course for the sake of clarity: we'll use tokio for the async runtime and a few other "async ready" primitives, and async-channel for the worker pool example, where we need a channel that can have multiple consumers. Both crates are widely used and well maintained, so it's hard to go wrong with them.
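
If you want to follow along, the examples assume dependencies along these lines in Cargo.toml (the version numbers are illustrative; check crates.io for current releases):

[dependencies]
tokio = { version = "1", features = ["full"] }
async-channel = "1"
futures = "0.3"
async-stream = "0.3"
rand = "0.8"
anyhow = "1"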

Scatter and Gather

This straightforward pattern consists of doing all of the things concurrently and collecting the results once they're done. It works well when all of the work is known ahead of time and there's no concern about doing it all at once. It might not work well if the work arrives on an unbounded stream (like client requests coming in over a TCP connection), or if only a certain number of tasks should run at the same time (think API rate limits). The later examples handle those restrictions.

use futures::future;
use rand::Rng;

struct Work {
    request: String,
}

struct Result {
    response: String,
}

async fn do_work(work: Work) -> Result {
    // Simulate a network request with a random delay.
    let delay_ms = rand::thread_rng().gen_range(500..1500);
    tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await;

    Result {
        response: format!("{}_processed", work.request),
    }
}

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    let work = (1..20).map(|n| Work {
        request: format!("item_{}", n),
    });

    let future_results = work.map(do_work);

    let results = future::join_all(future_results).await;

    for r in results {
        println!("{}", r.response);
    }

    Ok(())
}
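
One subtlety worth noting: join_all drives all of the futures from within a single task, so this is concurrency rather than parallelism. Since the work here is mostly waiting, that's fine. If you did want the runtime to spread the work across threads, a sketch like the following (reusing the same Work, Result, and do_work definitions) hands each future to tokio's multithreaded scheduler via tokio::spawn:

use futures::future;

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    // Spawn each piece of work as its own task so the scheduler
    // is free to run them on different threads.
    let handles: Vec<_> = (1..20)
        .map(|n| {
            tokio::spawn(do_work(Work {
                request: format!("item_{}", n),
            }))
        })
        .collect();

    // Each JoinHandle resolves to a Result, and `?` surfaces
    // a panicked task as an error.
    for handle in future::join_all(handles).await {
        println!("{}", handle?.response);
    }

    Ok(())
}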

Worker Pool

A worker pool is a way to limit the number of concurrent task executions. In the Scatter and Gather approach all tasks are attempted concurrently; a fixed pool of workers instead caps the number of concurrent tasks at the number of available workers. The code here is more involved and uses "multiple consumer" channels from the async-channel crate, since every worker receives from the same work queue.

use rand::Rng;

struct Work {
    request: String,
}

struct Result {
    response: String,
}

async fn do_work(work: Work) -> Result {
    // Simulate a network request with a random delay.
    let delay_ms = rand::thread_rng().gen_range(500..1500);
    tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await;

    Result {
        response: format!("{}_processed", work.request),
    }
}

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    let worker_count = 5;

    let (tx_work, rx_work): (async_channel::Sender<Work>, async_channel::Receiver<Work>) =
        async_channel::bounded(1);

    let (tx_result, rx_result): (
        async_channel::Sender<Result>,
        async_channel::Receiver<Result>,
    ) = async_channel::bounded(1);

    let mut workers = Vec::new();

    for _ in 0..worker_count {
        workers.push(tokio::spawn(worker(rx_work.clone(), tx_result.clone())));
    }

    let consumer = tokio::spawn(consumer(rx_result));

    for idx in 0..20 {
        tx_work
            .send(Work {
                request: format!("work_{}", idx),
            })
            .await?;
    }
    // Indicate that no more work will be sent by closing the channel.
    // This allows the worker loops to complete.
    drop(tx_work);

    // Wait for all workers to be done.
    futures::future::try_join_all(workers).await?;

    // Close the tx_result channel since the workers will not send on it anymore.
    // All of the workers have exited at this point and dropped their clones,
    // so dropping this last sender closes the channel.
    drop(tx_result);

    // Wait for the consumer to be done. All senders to the result channel
    // are closed, which allows the consumer loop to end.
    consumer.await?;

    Ok(())
}

async fn worker(input: async_channel::Receiver<Work>, output: async_channel::Sender<Result>) {
    loop {
        match input.recv().await {
            Ok(work) => {
                if output.send(do_work(work).await).await.is_err() {
                    return;
                }
            }
            Err(e) => {
                println!("shutting down worker: {}", e);
                return;
            }
        }
    }
}

async fn consumer(input: async_channel::Receiver<Result>) {
    loop {
        match input.recv().await {
            Ok(result) => println!("{}", result.response),
            Err(e) => {
                println!("shutting down consumer: {}", e);
                return;
            }
        }
    }
}
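
For contrast, tokio's own mpsc channel is single-consumer, so building the same pool without async-channel means sharing one receiver behind a mutex. A rough sketch of what that worker would look like (this is my own illustration, reusing Work, Result, and do_work from above):

use std::sync::Arc;
use tokio::sync::{mpsc, Mutex};

async fn worker(input: Arc<Mutex<mpsc::Receiver<Work>>>, output: mpsc::Sender<Result>) {
    loop {
        // Hold the lock only long enough to receive one item, so the
        // other workers can grab the next one while we process this one.
        let work = { input.lock().await.recv().await };
        match work {
            Some(work) => {
                if output.send(do_work(work).await).await.is_err() {
                    return;
                }
            }
            // None means the channel is closed and drained.
            None => return,
        }
    }
}

It works, but the locking ceremony is exactly what async-channel's cloneable receivers let us avoid.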

Semaphore

Perhaps a simpler approach to bounded concurrency is a counting semaphore. This example uses the tokio semaphore and a more standard mpsc channel with a single worker. That single worker spawns many concurrent tasks, but is limited to a fixed number of in-flight tasks by the semaphore.

use std::sync::Arc;

use rand::Rng;
use tokio::sync::{mpsc, Semaphore};

#[derive(Debug)]
struct Work {
    request: String,
}

#[derive(Debug)]
struct Result {
    response: String,
}

async fn do_work(work: Work) -> Result {
    // Simulate a network request with a random delay.
    let delay_ms = rand::thread_rng().gen_range(500..1500);
    tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await;

    Result {
        response: format!("{}_processed", work.request),
    }
}

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    let (tx_work, rx_work) = mpsc::channel(1);
    let (tx_result, mut rx_result) = mpsc::channel(1);

    tokio::spawn(worker(rx_work, tx_result));
    tokio::spawn(producer(tx_work));

    while let Some(result) = rx_result.recv().await {
        println!("{}", result.response);
    }

    Ok(())
}

async fn producer(tx_work: mpsc::Sender<Work>) {
    for idx in 0..20 {
        tx_work
            .send(Work {
                request: format!("work_{}", idx),
            })
            .await
            .unwrap();
    }
}

async fn worker(
    mut rx_work: mpsc::Receiver<Work>,
    tx_result: mpsc::Sender<Result>,
) -> anyhow::Result<()> {
    let semaphore = Arc::new(Semaphore::new(5));

    while let Some(work) = rx_work.recv().await {
        let permit = semaphore.clone().acquire_owned().await?;
        let tx = tx_result.clone();
        tokio::spawn(async move {
            tx.send(do_work(work).await).await.unwrap();
            // Release the permit once the result has been sent,
            // freeing a slot for the next task.
            drop(permit);
        });
    }

    Ok(())
}
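
The channel plumbing above earns its keep when work arrives from elsewhere, but if the work is generated locally, the semaphore can bound concurrency with far less machinery. A condensed sketch (reusing the same imports, Work, Result, and do_work):

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    let semaphore = Arc::new(Semaphore::new(5));
    let mut handles = Vec::new();

    for idx in 0..20 {
        // Once 5 permits are out, this await pauses until a task finishes.
        let permit = semaphore.clone().acquire_owned().await?;
        handles.push(tokio::spawn(async move {
            let result = do_work(Work {
                request: format!("work_{}", idx),
            })
            .await;
            drop(permit);
            result
        }));
    }

    for handle in handles {
        println!("{}", handle.await?.response);
    }

    Ok(())
}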

Buffered Stream

This final example uses a buffered stream to achieve bounded concurrent execution. The async_stream crate adapts a channel into a stream of futures, and buffering that stream allows several of those futures to be polled to completion concurrently. The stream essentially fills the role of the "worker" from the previous examples: internally, it manages how many tasks are executed at once.

use futures::StreamExt;
use rand::Rng;
use tokio::sync::mpsc;

#[derive(Debug)]
struct Work {
    request: String,
}

#[derive(Debug)]
struct Result {
    response: String,
}

async fn do_work(work: Work) -> Result {
    // Simulate a network request with a random delay.
    let delay_ms = rand::thread_rng().gen_range(500..1500);
    tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await;

    Result {
        response: format!("{}_processed", work.request),
    }
}

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    let (tx_work, rx_work) = mpsc::channel(100);

    let consumer = tokio::spawn(consumer(rx_work));

    for idx in 0..20 {
        tx_work
            .send(Work {
                request: format!("work_{}", idx),
            })
            .await?;
    }
    drop(tx_work);

    consumer.await?;

    Ok(())
}

async fn consumer(mut incoming: mpsc::Receiver<Work>) {
    let stream = async_stream::stream! {
        while let Some(item) = incoming.recv().await {
            yield do_work(item);
        }
    };

    let queue = stream.buffer_unordered(5);
    futures::pin_mut!(queue);

    while let Some(result) = queue.next().await {
        println!("{}_processed", result.response);
    }
}
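
When all of the work is known up front, the channel and async_stream aren't needed at all; the futures crate can build the buffered stream directly. A minimal sketch (again reusing Work, Result, and do_work):

use futures::stream::{self, StreamExt};

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    let work = (0..20).map(|idx| Work {
        request: format!("work_{}", idx),
    });

    // `map` turns each Work item into a future, and `buffer_unordered`
    // keeps up to 5 of those futures running at once.
    let results = stream::iter(work).map(do_work).buffer_unordered(5);
    futures::pin_mut!(results);

    while let Some(result) = results.next().await {
        println!("{}", result.response);
    }

    Ok(())
}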
