rust-concurrency - 解答

実装コード

スレッドセーフカウンター

// src/counter.rs

use std::sync::{Arc, Mutex};

/// A `u64` counter that can be shared between threads.
///
/// Cloning a `Counter` is cheap and yields another handle to the **same** value
/// (the count lives behind an `Arc<Mutex<_>>`), so changes made through any clone
/// are visible through all of them.
#[derive(Clone, Debug)]
pub struct Counter {
    count: Arc<Mutex<u64>>,
}

impl Counter {
    /// Creates a new counter starting at zero.
    pub fn new() -> Self {
        Counter {
            count: Arc::new(Mutex::new(0)),
        }
    }

    /// Adds one to the counter.
    pub fn increment(&self) {
        *self.lock() += 1;
    }

    /// Subtracts one from the counter, stopping at zero instead of underflowing.
    pub fn decrement(&self) {
        let mut count = self.lock();
        *count = count.saturating_sub(1);
    }

    /// Returns the current value.
    pub fn get(&self) -> u64 {
        *self.lock()
    }

    /// Sets the counter back to zero.
    pub fn reset(&self) {
        *self.lock() = 0;
    }

    /// Locks the shared value, recovering the guard if the mutex is poisoned.
    ///
    /// A plain `u64` has no invariant that a thread panicking while holding the lock
    /// could have broken, so continuing to use the value is sound. This avoids every
    /// subsequent call panicking just because one holder panicked earlier.
    fn lock(&self) -> std::sync::MutexGuard<'_, u64> {
        self.count
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }
}

impl Default for Counter {
    fn default() -> Self {
        Self::new()
    }
}

チャネルベースワーカー

// src/channel.rs

use std::sync::mpsc;
use std::thread;

/// Runs `process` over `items` on `worker_count` threads and collects every result.
///
/// Items are queued on a shared channel, and each worker pulls the next pending item
/// as soon as it is free. Results are gathered as they complete, so their order is
/// **not** guaranteed to match the input order.
///
/// A `worker_count` of `0` is treated as `1`, so items are never silently dropped.
///
/// # Panics
///
/// If `process` panics for any item, that panic is re-raised on the calling thread
/// once every worker has stopped.
pub fn spawn_workers<T, F, R>(
    items: Vec<T>,
    worker_count: usize,
    process: F,
) -> Vec<R>
where
    T: Send + 'static,
    R: Send + 'static,
    F: Fn(T) -> R + Send + Clone + 'static,
{
    let items_len = items.len();

    // The channel is unbounded, so the whole input can be queued up front without a
    // separate sender thread. Dropping `tx` afterwards makes `recv` report
    // disconnection once the queue is drained, which is how workers know to stop.
    let (tx, rx) = mpsc::channel();
    for item in items {
        tx.send(item)
            .expect("receiver is still owned by this function");
    }
    drop(tx);

    let rx = Arc::new(std::sync::Mutex::new(rx));
    let (result_tx, result_rx) = mpsc::channel();

    let handles: Vec<_> = (0..worker_count.max(1))
        .map(|_| {
            let rx = Arc::clone(&rx);
            let result_tx = result_tx.clone();
            let process = process.clone();
            thread::spawn(move || loop {
                // Block until an item arrives instead of spinning on `try_recv`. The
                // guard is a temporary of this statement, so the lock is released
                // before `process` runs and other workers can fetch items meanwhile.
                let next = rx.lock().unwrap().recv();
                match next {
                    Ok(item) => result_tx
                        .send(process(item))
                        .expect("result receiver lives until all workers exit"),
                    Err(mpsc::RecvError) => break,
                }
            })
        })
        .collect();

    // Drop our own result sender so collection ends when the last worker exits.
    drop(result_tx);

    let mut results = Vec::with_capacity(items_len);
    results.extend(result_rx);

    for handle in handles {
        // Surface worker panics instead of silently returning fewer results.
        if let Err(payload) = handle.join() {
            std::panic::resume_unwind(payload);
        }
    }

    results
}

use std::sync::Arc;

スレッドプール

// src/pool.rs

use std::sync::{mpsc, Arc, Mutex};
use std::thread;

/// A unit of work submitted to the [`ThreadPool`].
type Job = Box<dyn FnOnce() + Send + 'static>;

/// A fixed-size pool of worker threads that execute submitted closures.
///
/// Jobs are queued on a shared channel and picked up by whichever worker is free.
/// Dropping the pool closes the queue, lets the workers finish every job that was
/// already submitted, and then joins all worker threads.
pub struct ThreadPool {
    workers: Vec<Worker>,
    // `None` only while the pool is being dropped.
    sender: Option<mpsc::Sender<Job>>,
}

impl ThreadPool {
    /// Creates a pool with `size` worker threads.
    ///
    /// # Panics
    ///
    /// Panics if `size` is zero.
    pub fn new(size: usize) -> Self {
        assert!(size > 0, "Thread pool size must be positive");

        let (sender, receiver) = mpsc::channel();
        let receiver = Arc::new(Mutex::new(receiver));

        let workers = (0..size)
            .map(|id| Worker::new(id, Arc::clone(&receiver)))
            .collect();

        ThreadPool {
            workers,
            sender: Some(sender),
        }
    }

    /// Queues `f` to run on one of the pool's worker threads.
    ///
    /// A panic inside `f` is caught by the worker that runs it, so it neither
    /// shrinks the pool nor prevents later jobs from running.
    pub fn execute<F>(&self, f: F)
    where
        F: FnOnce() + Send + 'static,
    {
        let job: Job = Box::new(f);
        if let Some(sender) = &self.sender {
            // Workers only exit after the sender is dropped, so the receiving end
            // is alive for as long as `execute` can be called.
            sender
                .send(job)
                .expect("thread pool workers stopped before the pool was dropped");
        }
    }
}

impl Drop for ThreadPool {
    fn drop(&mut self) {
        // Dropping the sender closes the channel; each worker exits once the queue
        // has been drained and `recv` reports disconnection.
        drop(self.sender.take());

        for worker in &mut self.workers {
            if let Some(thread) = worker.thread.take() {
                // Job panics are caught inside the worker loop, so a join error means
                // the worker loop itself failed. Do not panic again while already
                // unwinding, because a double panic aborts the process.
                if thread.join().is_err() && !thread::panicking() {
                    panic!("thread pool worker {} panicked", worker.id);
                }
            }
        }
    }
}

/// A single pool thread together with its index in the pool.
struct Worker {
    id: usize,
    // `None` once the thread has been joined by `ThreadPool::drop`.
    thread: Option<thread::JoinHandle<()>>,
}

impl Worker {
    /// Spawns a thread that runs jobs from `receiver` until the channel is closed.
    fn new(id: usize, receiver: Arc<Mutex<mpsc::Receiver<Job>>>) -> Self {
        let thread = thread::spawn(move || loop {
            // The guard is a temporary of this statement, so the lock is released
            // before the job runs and other workers can fetch jobs concurrently.
            let message = receiver.lock().unwrap().recv();

            match message {
                Ok(job) => {
                    // Deliberately ignore the result: a failing job must not kill this
                    // worker. The panic hook still reports the panic itself.
                    let _ = std::panic::catch_unwind(std::panic::AssertUnwindSafe(job));
                }
                // Every sender is gone: the pool is shutting down.
                Err(_) => break,
            }
        });

        Worker {
            id,
            thread: Some(thread),
        }
    }
}

並列マップ

// src/parallel.rs

use std::sync::Arc;
use std::thread;

/// Applies `f` to every item on a dedicated thread and returns the results in input order.
///
/// One OS thread is spawned per item, so this suits only small inputs; `parallel_map_n`
/// bounds the number of threads instead.
///
/// # Panics
///
/// Panics if `f` panics for any item (the failure surfaces when that thread is joined).
pub fn parallel_map<T, U, F>(items: Vec<T>, f: F) -> Vec<U>
where
    T: Send + 'static,
    U: Send + 'static,
    F: Fn(T) -> U + Send + Sync + 'static,
{
    let shared = Arc::new(f);

    // Spawn every thread first so all items run concurrently.
    let mut workers = Vec::with_capacity(items.len());
    for item in items {
        let task = Arc::clone(&shared);
        workers.push(thread::spawn(move || task(item)));
    }

    // Join in spawn order, which is input order.
    let mut output = Vec::with_capacity(workers.len());
    for worker in workers {
        output.push(worker.join().unwrap());
    }
    output
}

/// Parallel map that uses at most `n` threads and returns results in input order.
///
/// The input is split into at most `n` contiguous chunks of near-equal size, and each
/// chunk is processed sequentially on its own thread. `n == 0` is treated as `1`, and an
/// empty input returns an empty `Vec` without spawning any threads. Items are moved into
/// the chunks, so `T` does not need to implement `Clone`.
///
/// # Panics
///
/// If `f` panics for any item, the panic of the first failing chunk (in chunk order) is
/// re-raised on the calling thread.
pub fn parallel_map_n<T, U, F>(items: Vec<T>, n: usize, f: F) -> Vec<U>
where
    T: Send + 'static,
    U: Send + 'static,
    F: Fn(T) -> U + Send + Sync + 'static,
{
    let len = items.len();
    if len == 0 {
        // `chunks(0)`-style splitting is impossible; nothing to do anyway.
        return Vec::new();
    }

    // Never spawn more threads than items, and treat `n == 0` as a single thread.
    let threads = n.clamp(1, len);
    let chunk_size = (len + threads - 1) / threads;

    let f = Arc::new(f);
    let mut remaining = items.into_iter();
    let mut handles = Vec::with_capacity(threads);
    loop {
        // Move (not clone) the next contiguous chunk out of the input.
        let chunk: Vec<T> = remaining.by_ref().take(chunk_size).collect();
        if chunk.is_empty() {
            break;
        }
        let f = Arc::clone(&f);
        handles.push(thread::spawn(move || {
            chunk.into_iter().map(|item| f(item)).collect::<Vec<U>>()
        }));
    }

    // Chunks are contiguous and joined in spawn order, so concatenating their
    // outputs restores the input order without sorting.
    let mut results = Vec::with_capacity(len);
    for handle in handles {
        match handle.join() {
            Ok(part) => results.extend(part),
            Err(payload) => std::panic::resume_unwind(payload),
        }
    }
    results
}

テストコード

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_counter() {
        let counter = Counter::new();
        let handles: Vec<_> = (0..10)
            .map(|_| {
                let c = counter.clone();
                thread::spawn(move || {
                    for _ in 0..100 {
                        c.increment();
                    }
                })
            })
            .collect();

        for h in handles {
            h.join().unwrap();
        }
        assert_eq!(counter.get(), 1000);
    }

    #[test]
    fn test_thread_pool() {
        let pool = ThreadPool::new(4);
        let counter = Arc::new(Mutex::new(0));

        for _ in 0..100 {
            let c = Arc::clone(&counter);
            pool.execute(move || {
                let mut count = c.lock().unwrap();
                *count += 1;
            });
        }

        drop(pool); // Dropping the pool waits for all workers to finish.
        assert_eq!(*counter.lock().unwrap(), 100);
    }

    #[test]
    fn test_parallel_map() {
        let items: Vec<i32> = (0..100).collect();
        let results = parallel_map(items, |x| x * 2);
        // Check values and order, not just the length.
        let expected: Vec<i32> = (0..100).map(|x| x * 2).collect();
        assert_eq!(results, expected);
    }

    #[test]
    fn test_parallel_map_n_preserves_order() {
        let items: Vec<i32> = (0..100).collect();
        let results = parallel_map_n(items, 4, |x| x + 1);
        assert_eq!(results, (1..=100).collect::<Vec<i32>>());
    }

    #[test]
    fn test_spawn_workers() {
        let items: Vec<i32> = (0..100).collect();
        let mut results = spawn_workers(items, 4, |x| x * 3);
        // spawn_workers does not guarantee result order, so sort before comparing.
        results.sort_unstable();
        assert_eq!(results, (0..100).map(|x| x * 3).collect::<Vec<i32>>());
    }
}