rust-concurrency - 解答
実装コード
スレッドセーフカウンター
// src/counter.rs
use std::sync::{Arc, Mutex};
/// A cloneable, thread-safe `u64` counter.
///
/// The value lives behind an `Arc<Mutex<u64>>`, so every clone shares the same
/// count: a clone handed to another thread updates the original as well.
#[derive(Clone)]
pub struct Counter {
    count: Arc<Mutex<u64>>,
}

impl Counter {
    /// Creates a counter whose value starts at zero.
    pub fn new() -> Self {
        let shared = Mutex::new(0);
        Self {
            count: Arc::new(shared),
        }
    }

    /// Adds one to the shared value.
    pub fn increment(&self) {
        *self.count.lock().unwrap() += 1;
    }

    /// Subtracts one from the shared value, stopping at zero instead of underflowing.
    pub fn decrement(&self) {
        let mut guard = self.count.lock().unwrap();
        if *guard > 0 {
            *guard -= 1;
        }
    }

    /// Returns a snapshot of the current value.
    pub fn get(&self) -> u64 {
        let guard = self.count.lock().unwrap();
        *guard
    }

    /// Sets the shared value back to zero.
    pub fn reset(&self) {
        *self.count.lock().unwrap() = 0;
    }
}

impl Default for Counter {
    /// Equivalent to [`Counter::new`].
    fn default() -> Self {
        Counter::new()
    }
}
チャネルベースワーカー
// src/channel.rs
use std::sync::mpsc;
use std::thread;
/// Runs `process` over `items` on a fixed set of worker threads and gathers the outputs.
///
/// A feeder thread pushes every item into a shared queue. `worker_count` workers block on
/// that queue (no busy-waiting), apply `process` to each item, and send the output back.
///
/// Outputs are returned in completion order, which in general is *not* the input order.
/// A `worker_count` of `0` is treated as `1`, so every item is always processed.
///
/// # Panics
///
/// If `process` panics on a worker, that panic is re-raised on the calling thread once all
/// workers have stopped.
pub fn spawn_workers<T, F, R>(
    items: Vec<T>,
    worker_count: usize,
    process: F,
) -> Vec<R>
where
    T: Send + 'static,
    R: Send + 'static,
    F: Fn(T) -> R + Send + Clone + 'static,
{
    let items_len = items.len();
    let (item_tx, item_rx) = mpsc::channel();
    let (result_tx, result_rx) = mpsc::channel();

    // Feed the queue from its own thread so workers can start before all items are sent.
    let feeder = thread::spawn(move || {
        for item in items {
            // `Err` means every receiver is gone; the remaining items cannot be delivered.
            if item_tx.send(item).is_err() {
                break;
            }
        }
        // `item_tx` is dropped here, so workers see a disconnect once the queue drains.
    });

    // An mpsc receiver has a single consumer, so the workers share it behind a mutex.
    let item_rx = Arc::new(std::sync::Mutex::new(item_rx));
    let workers: Vec<_> = (0..worker_count.max(1))
        .map(|_| {
            let item_rx = Arc::clone(&item_rx);
            let result_tx = result_tx.clone();
            let process = process.clone();
            thread::spawn(move || loop {
                // The guard is a temporary of this statement. The lock is therefore released
                // before `process` runs, and other workers can dequeue in the meantime.
                let next = item_rx.lock().unwrap().recv();
                match next {
                    Ok(item) => {
                        if result_tx.send(process(item)).is_err() {
                            break;
                        }
                    }
                    // The feeder has finished and the queue is empty.
                    Err(_) => break,
                }
            })
        })
        .collect();

    // Drop our own result sender so the collection below ends once every worker has exited.
    drop(result_tx);
    let mut results = Vec::with_capacity(items_len);
    results.extend(result_rx);

    // Re-raise panics instead of silently returning fewer results than items.
    for handle in std::iter::once(feeder).chain(workers) {
        if let Err(payload) = handle.join() {
            std::panic::resume_unwind(payload);
        }
    }
    results
}
use std::sync::Arc;
スレッドプール
// src/pool.rs
use std::sync::{mpsc, Arc, Mutex};
use std::thread;
/// A type-erased unit of work that a worker thread runs exactly once.
type Job = Box<dyn FnOnce() + Send + 'static>;

/// A fixed-size pool of worker threads that execute submitted closures.
///
/// Jobs are queued on an `mpsc` channel whose receiver all workers share behind a `Mutex`.
/// Dropping the pool closes the channel and joins every worker. All jobs submitted before
/// the drop have therefore run by the time `drop` returns.
pub struct ThreadPool {
    workers: Vec<Worker>,
    // `Option` so `Drop` can take the sender and drop it, disconnecting the workers.
    sender: Option<mpsc::Sender<Job>>,
}

impl ThreadPool {
    /// Creates a pool with `size` worker threads.
    ///
    /// # Panics
    ///
    /// Panics if `size` is zero or if the OS fails to spawn a worker thread.
    pub fn new(size: usize) -> Self {
        assert!(size > 0, "Thread pool size must be positive");
        let (sender, receiver) = mpsc::channel();
        let receiver = Arc::new(Mutex::new(receiver));
        let workers: Vec<Worker> = (0..size)
            .map(|id| Worker::new(id, Arc::clone(&receiver)))
            .collect();
        ThreadPool {
            workers,
            sender: Some(sender),
        }
    }

    /// Queues `f` to run on one of the pool's workers.
    ///
    /// A job that panics does not take its worker down. The standard panic hook reports
    /// the panic, and the worker then continues with the next queued job.
    pub fn execute<F>(&self, f: F)
    where
        F: FnOnce() + Send + 'static,
    {
        let job: Job = Box::new(f);
        if let Some(sender) = &self.sender {
            // Workers only exit after the sender is dropped (in `Drop`), so the receiver
            // is still alive here.
            sender
                .send(job)
                .expect("thread pool workers have all exited");
        }
    }
}

impl Drop for ThreadPool {
    fn drop(&mut self) {
        // Dropping the sender disconnects the channel. Each worker's `recv` then returns
        // `Err` once the queue is drained, which ends its loop.
        drop(self.sender.take());
        for worker in &mut self.workers {
            if let Some(handle) = worker.thread.take() {
                // Jobs run under `catch_unwind`, so a worker can only die from a bug in the
                // pool itself.
                if handle.join().is_err() {
                    panic!("thread pool worker {} panicked outside of a job", worker.id);
                }
            }
        }
    }
}

/// One pool thread together with the index it was created with.
struct Worker {
    id: usize,
    // `Option` so `ThreadPool::drop` can move the handle out and join it.
    thread: Option<thread::JoinHandle<()>>,
}

impl Worker {
    /// Spawns a thread named `pool-worker-{id}` that runs jobs from `receiver` until the
    /// channel disconnects.
    fn new(id: usize, receiver: Arc<Mutex<mpsc::Receiver<Job>>>) -> Self {
        let thread = thread::Builder::new()
            .name(format!("pool-worker-{id}"))
            .spawn(move || loop {
                // The guard is a temporary of this statement, so the lock is released
                // before the job runs.
                let message = receiver.lock().unwrap().recv();
                match message {
                    Ok(job) => {
                        // Keep the worker alive if the job panics. The panic hook has already
                        // reported the panic, so the Err payload is intentionally discarded.
                        let _ = std::panic::catch_unwind(std::panic::AssertUnwindSafe(job));
                    }
                    // The sender was dropped and the queue is empty: shut down.
                    Err(_) => break,
                }
            })
            .expect("failed to spawn thread pool worker");
        Worker {
            id,
            thread: Some(thread),
        }
    }
}
並列マップ
// src/parallel.rs
use std::sync::Arc;
use std::thread;
/// Applies `f` to every item, each on its own thread, and returns the outputs in input order.
///
/// One OS thread is spawned per item, so this suits small inputs. `parallel_map_n` caps the
/// thread count.
///
/// # Panics
///
/// Panics if any invocation of `f` panics, because the failed thread's `join` is unwrapped.
pub fn parallel_map<T, U, F>(items: Vec<T>, f: F) -> Vec<U>
where
    T: Send + 'static,
    U: Send + 'static,
    F: Fn(T) -> U + Send + Sync + 'static,
{
    let shared = Arc::new(f);

    // Spawn every thread first so they all run concurrently...
    let mut pending = Vec::new();
    for item in items {
        let worker_fn = Arc::clone(&shared);
        pending.push(thread::spawn(move || worker_fn(item)));
    }

    // ...then join them in spawn order, which keeps the outputs aligned with the inputs.
    let mut outputs = Vec::with_capacity(pending.len());
    for handle in pending {
        outputs.push(handle.join().unwrap());
    }
    outputs
}
/// Parallel map that uses at most `n` threads and returns the outputs in input order.
///
/// The input is split into up to `n` contiguous chunks of roughly equal size. Each chunk is
/// mapped on its own thread, and the per-chunk outputs are concatenated in chunk order.
/// `T` does not need to be `Clone`, because items are moved into their chunk.
///
/// An `n` of `0` is treated as `1`. No more threads than items are spawned, and empty input
/// returns an empty `Vec` without spawning any thread.
///
/// # Panics
///
/// If `f` panics on any thread, that panic is re-raised on the calling thread.
pub fn parallel_map_n<T, U, F>(items: Vec<T>, n: usize, f: F) -> Vec<U>
where
    T: Send + 'static,
    U: Send + 'static,
    F: Fn(T) -> U + Send + Sync + 'static,
{
    let len = items.len();
    if len == 0 {
        return Vec::new();
    }
    // At least one thread, and never more threads than there are items.
    let threads = n.clamp(1, len);
    // Ceiling division so that at most `threads` chunks are produced.
    let chunk_size = (len + threads - 1) / threads;

    let f = Arc::new(f);
    let mut remaining = items.into_iter();
    let mut handles = Vec::with_capacity(threads);
    loop {
        // Move the next contiguous run of items out of the input; no cloning needed.
        let chunk: Vec<T> = remaining.by_ref().take(chunk_size).collect();
        if chunk.is_empty() {
            break;
        }
        let f = Arc::clone(&f);
        handles.push(thread::spawn(move || {
            chunk.into_iter().map(&*f).collect::<Vec<U>>()
        }));
    }

    // Joining in spawn order keeps the chunks, and therefore the outputs, in input order.
    let mut results = Vec::with_capacity(len);
    for handle in handles {
        match handle.join() {
            Ok(part) => results.extend(part),
            Err(payload) => std::panic::resume_unwind(payload),
        }
    }
    results
}
テストコード
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_counter() {
        let counter = Counter::new();
        let handles: Vec<_> = (0..10)
            .map(|_| {
                let c = counter.clone();
                thread::spawn(move || {
                    for _ in 0..100 {
                        c.increment();
                    }
                })
            })
            .collect();
        for h in handles {
            h.join().unwrap();
        }
        assert_eq!(counter.get(), 1000);
    }

    #[test]
    fn test_counter_decrement_saturates_at_zero() {
        let counter = Counter::new();
        counter.decrement();
        assert_eq!(counter.get(), 0);
        counter.increment();
        counter.increment();
        counter.decrement();
        assert_eq!(counter.get(), 1);
        counter.reset();
        assert_eq!(counter.get(), 0);
    }

    #[test]
    fn test_thread_pool() {
        let pool = ThreadPool::new(4);
        let counter = Arc::new(Mutex::new(0));
        for _ in 0..100 {
            let c = Arc::clone(&counter);
            pool.execute(move || {
                let mut count = c.lock().unwrap();
                *count += 1;
            });
        }
        drop(pool); // Dropping the pool joins every worker, so all queued jobs have run.
        assert_eq!(*counter.lock().unwrap(), 100);
    }

    #[test]
    fn test_parallel_map() {
        let items: Vec<i32> = (0..100).collect();
        let results = parallel_map(items, |x| x * 2);
        // Check the values and their order, not just the length.
        let expected: Vec<i32> = (0..100).map(|x| x * 2).collect();
        assert_eq!(results, expected);
    }
}