rust-network - 解答

実装コード

TCPエコーサーバー

// src/tcp.rs
use std::io::{Read, Write};
use std::net::{TcpListener, TcpStream};
use std::thread;

/// Multi-threaded TCP echo server: each accepted connection is served on its own thread.
pub struct TcpEchoServer {
    /// The bound listening socket that `run` accepts connections from.
    listener: TcpListener,
}

impl TcpEchoServer {
    /// Binds a listening socket to `addr` (for example `"127.0.0.1:19000"`).
    ///
    /// # Errors
    /// Returns any error produced by `TcpListener::bind`.
    pub fn bind(addr: &str) -> std::io::Result<Self> {
        TcpListener::bind(addr).map(|listener| Self { listener })
    }

    /// Accepts connections forever, handing each one to a freshly spawned thread.
    ///
    /// Accept failures and per-client failures are logged to stderr and do not stop
    /// the server.
    ///
    /// # Errors
    /// Returns an error if the listener's local address cannot be read. The std
    /// `Incoming` iterator never yields `None`, so the accept loop itself never ends.
    pub fn run(&self) -> std::io::Result<()> {
        let local = self.listener.local_addr()?;
        println!("TCP Echo Server listening on {}", local);

        for incoming in self.listener.incoming() {
            let stream = match incoming {
                Ok(stream) => stream,
                Err(err) => {
                    eprintln!("Accept error: {}", err);
                    continue;
                }
            };

            // Detached worker: the server does not wait for clients to finish.
            thread::spawn(move || {
                if let Err(err) = handle_client(stream) {
                    eprintln!("Client error: {}", err);
                }
            });
        }

        Ok(())
    }
}

/// Echoes everything read from `stream` back to the same peer until the peer closes
/// its sending half (EOF).
///
/// Logs the peer address when the connection starts and again on a clean EOF.
///
/// # Errors
/// Returns the first read error other than `ErrorKind::Interrupted`, or any write
/// error. The "Connection closed" message is only printed on a clean EOF.
fn handle_client(mut stream: TcpStream) -> std::io::Result<()> {
    let peer = stream.peer_addr()?;
    println!("New connection from {}", peer);

    let mut buffer = [0u8; 4096];

    loop {
        let n = match stream.read(&mut buffer) {
            // EOF: the client has shut down its write side.
            Ok(0) => break,
            Ok(n) => n,
            // `Read::read` documents `Interrupted` as non-fatal; retry instead of
            // tearing down a healthy connection.
            Err(e) if e.kind() == std::io::ErrorKind::Interrupted => continue,
            Err(e) => return Err(e),
        };

        // `write_all` handles partial writes and `Interrupted` internally. TcpStream is
        // unbuffered, so no explicit flush is needed afterwards.
        stream.write_all(&buffer[..n])?;
    }

    println!("Connection closed: {}", peer);
    Ok(())
}

// クライアント
/// Blocking client for a TCP echo server; one connection is reused for every `send`.
pub struct TcpEchoClient {
    /// The connected socket.
    stream: TcpStream,
}

impl TcpEchoClient {
    /// Opens a TCP connection to `addr` (for example `"127.0.0.1:19000"`).
    ///
    /// # Errors
    /// Returns any error produced by `TcpStream::connect`.
    pub fn connect(addr: &str) -> std::io::Result<Self> {
        TcpStream::connect(addr).map(|stream| Self { stream })
    }

    /// Sends `data`, then blocks until exactly `data.len()` bytes have been read back.
    ///
    /// # Errors
    /// Returns any write error, or a read error, including `UnexpectedEof` if the
    /// server closes the connection before echoing the full payload.
    pub fn send(&mut self, data: &[u8]) -> std::io::Result<Vec<u8>> {
        let stream = &mut self.stream;
        stream.write_all(data)?;
        stream.flush()?;

        let mut echoed = vec![0u8; data.len()];
        stream.read_exact(&mut echoed)?;
        Ok(echoed)
    }
}

UDPチャット

// src/udp.rs
use std::collections::HashSet;
use std::net::{SocketAddr, UdpSocket};
use std::sync::{Arc, Mutex};
use std::thread;

/// Minimal peer-to-peer UDP chat endpoint.
///
/// Outgoing messages go to every known peer. Any address a datagram is received from
/// is added to the peer set automatically.
pub struct UdpChat {
    /// Socket bound to `0.0.0.0:<port>` with broadcast enabled.
    socket: UdpSocket,
    /// Known peers, shared with the background receiver thread.
    peers: Arc<Mutex<HashSet<SocketAddr>>>,
}

impl UdpChat {
    /// Binds a UDP socket on all IPv4 interfaces at `port` and enables broadcast.
    ///
    /// # Errors
    /// Returns any error from binding the socket or setting the broadcast flag.
    pub fn new(port: u16) -> std::io::Result<Self> {
        let socket = UdpSocket::bind(format!("0.0.0.0:{}", port))?;
        socket.set_broadcast(true)?;

        Ok(UdpChat {
            socket,
            peers: Arc::new(Mutex::new(HashSet::new())),
        })
    }

    /// Adds `addr` to the set of peers that `send_message` delivers to.
    pub fn add_peer(&self, addr: SocketAddr) {
        self.peers.lock().unwrap().insert(addr);
    }

    /// Sends `msg` (its UTF-8 bytes) to every known peer.
    ///
    /// Every peer is attempted even if an earlier send fails, so one unreachable
    /// address cannot stop delivery to the others.
    ///
    /// # Errors
    /// Returns the first `send_to` error encountered, after all peers have been tried.
    pub fn send_message(&self, msg: &str) -> std::io::Result<()> {
        // Copy the addresses out so the lock is not held during socket I/O
        // (the receiver thread needs it to register new peers).
        let peers: Vec<SocketAddr> = self.peers.lock().unwrap().iter().copied().collect();

        let mut first_error = None;
        for peer in peers {
            if let Err(e) = self.socket.send_to(msg.as_bytes(), peer) {
                first_error.get_or_insert(e);
            }
        }
        first_error.map_or(Ok(()), Err)
    }

    /// Blocks until one datagram arrives, then registers its sender as a peer.
    /// The payload is decoded as UTF-8, with invalid sequences replaced by U+FFFD.
    ///
    /// # Errors
    /// Returns any error from `recv_from` (including a timeout if one is configured).
    pub fn receive(&self) -> std::io::Result<(String, SocketAddr)> {
        let mut buffer = [0u8; 65535];
        let (n, from) = self.socket.recv_from(&mut buffer)?;
        let msg = Self::accept_datagram(&self.peers, &buffer[..n], from);
        Ok((msg, from))
    }

    /// Spawns a detached thread that receives datagrams forever. Each decoded message
    /// is passed to `callback` together with its sender.
    ///
    /// Receive errors are logged to stderr and the loop continues.
    ///
    /// # Panics
    /// Panics if the socket cannot be cloned for the receiver thread.
    pub fn run_receiver<F>(&self, callback: F)
    where
        F: Fn(String, SocketAddr) + Send + 'static,
    {
        let socket = self
            .socket
            .try_clone()
            .expect("failed to clone UDP socket for receiver thread");
        let peers = Arc::clone(&self.peers);

        thread::spawn(move || {
            let mut buffer = [0u8; 65535];
            loop {
                match socket.recv_from(&mut buffer) {
                    Ok((n, from)) => {
                        let msg = Self::accept_datagram(&peers, &buffer[..n], from);
                        callback(msg, from);
                    }
                    Err(e) => eprintln!("Receive error: {}", e),
                }
            }
        });
    }

    /// Records `from` as a peer and decodes `payload` lossily as UTF-8.
    fn accept_datagram(
        peers: &Mutex<HashSet<SocketAddr>>,
        payload: &[u8],
        from: SocketAddr,
    ) -> String {
        peers.lock().unwrap().insert(from);
        String::from_utf8_lossy(payload).into_owned()
    }
}

DNSリゾルバー

// src/dns.rs
use std::net::{IpAddr, Ipv4Addr, UdpSocket};

/// Upstream recursive resolver that `resolve` queries (Google Public DNS over UDP).
const DNS_SERVER: &str = "8.8.8.8:53";

/// Looks up the A-record (IPv4) addresses of `hostname` with one UDP query to
/// `DNS_SERVER`.
///
/// # Errors
/// Returns socket errors, including a timeout if no reply arrives within 5 seconds.
pub fn resolve(hostname: &str) -> std::io::Result<Vec<IpAddr>> {
    const REPLY_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(5);

    // Ephemeral local port on all IPv4 interfaces.
    let socket = UdpSocket::bind("0.0.0.0:0")?;
    socket.set_read_timeout(Some(REPLY_TIMEOUT))?;

    socket.send_to(&build_dns_query(hostname), DNS_SERVER)?;

    // 512 bytes is the classic DNS-over-UDP message size limit.
    let mut reply = [0u8; 512];
    let (len, _from) = socket.recv_from(&mut reply)?;

    parse_dns_response(&reply[..len])
}

/// Builds a DNS query message asking for the A (IPv4) records of `hostname`.
///
/// The message uses a fixed transaction ID of 1, sets the RD (recursion desired)
/// flag, and carries a single question of type A, class IN.
///
/// Empty labels are skipped. A fully-qualified name with a trailing dot
/// (`"example.com."`) therefore encodes the same as `"example.com"`, and `""`
/// encodes the root name.
fn build_dns_query(hostname: &str) -> Vec<u8> {
    // Header (12) + encoded name (at most hostname.len() + 2) + QTYPE/QCLASS (4).
    let mut query = Vec::with_capacity(12 + hostname.len() + 2 + 4);

    query.extend_from_slice(&[
        0x00, 0x01, // Transaction ID
        0x01, 0x00, // Flags: standard query, recursion desired
        0x00, 0x01, // QDCOUNT: 1
        0x00, 0x00, // ANCOUNT: 0
        0x00, 0x00, // NSCOUNT: 0
        0x00, 0x00, // ARCOUNT: 0
    ]);

    // QNAME: length-prefixed labels, terminated by the zero-length root label.
    for label in hostname.split('.').filter(|label| !label.is_empty()) {
        // NOTE: DNS labels are limited to 63 bytes. Longer labels are invalid, and this
        // `as u8` length would not describe them correctly. Callers are expected to
        // pass well-formed hostnames.
        query.push(label.len() as u8);
        query.extend_from_slice(label.as_bytes());
    }
    query.push(0x00);

    query.extend_from_slice(&[0x00, 0x01]); // QTYPE: A
    query.extend_from_slice(&[0x00, 0x01]); // QCLASS: IN

    query
}

/// Extracts the IPv4 addresses from the A records in a DNS response message.
///
/// Parsing is best-effort and never panics:
/// - A message shorter than the 12-byte header yields an empty list.
/// - A truncated or malformed record stops parsing; the addresses decoded before it
///   are kept.
///
/// Records of other types (e.g. CNAME) are skipped. The transaction ID, flags and
/// response code are not checked.
///
/// # Errors
/// Currently always returns `Ok`; the `Result` is kept so callers can use `?`.
fn parse_dns_response(response: &[u8]) -> std::io::Result<Vec<IpAddr>> {
    let mut addrs = Vec::new();
    // `None` only means "stopped early"; whatever was collected is still returned.
    let _ = collect_a_records(response, &mut addrs);
    Ok(addrs)
}

/// Walks the question and answer sections, appending every A record to `addrs`.
/// Returns `None` as soon as the message runs out of bytes or is malformed.
fn collect_a_records(msg: &[u8], addrs: &mut Vec<IpAddr>) -> Option<()> {
    const HEADER_LEN: usize = 12;
    const TYPE_A: u16 = 1;

    if msg.len() < HEADER_LEN {
        return None;
    }
    let question_count = read_be_u16(msg, 4)?;
    let answer_count = read_be_u16(msg, 6)?;

    let mut pos = HEADER_LEN;
    for _ in 0..question_count {
        // QNAME, then QTYPE (2) and QCLASS (2).
        pos = skip_name(msg, pos)? + 4;
    }

    for _ in 0..answer_count {
        pos = skip_name(msg, pos)?;
        // Fixed RR fields: TYPE (2), CLASS (2), TTL (4), RDLENGTH (2), then RDATA.
        let rtype = read_be_u16(msg, pos)?;
        let rdlength = usize::from(read_be_u16(msg, pos + 8)?);
        let rdata = msg.get(pos + 10..pos + 10 + rdlength)?;

        if let (TYPE_A, &[a, b, c, d]) = (rtype, rdata) {
            addrs.push(IpAddr::V4(Ipv4Addr::new(a, b, c, d)));
        }
        pos += 10 + rdlength;
    }
    Some(())
}

/// Returns the offset just past the (possibly compressed) domain name starting at
/// `pos`. Returns `None` if the name runs past the end of `msg` or uses a reserved
/// label type.
fn skip_name(msg: &[u8], mut pos: usize) -> Option<usize> {
    loop {
        let len = *msg.get(pos)?;
        match len & 0xC0 {
            // Ordinary label: length byte followed by `len` bytes; 0 ends the name.
            0x00 => {
                if len == 0 {
                    return Some(pos + 1);
                }
                pos += 1 + usize::from(len);
            }
            // Compression pointer (2 bytes). It always terminates the name in place, so
            // the target is never followed and loops are impossible.
            0xC0 => {
                msg.get(pos + 1)?;
                return Some(pos + 2);
            }
            _ => return None,
        }
    }
}

/// Reads a big-endian `u16` at `pos`, or `None` if it would run past the end of `msg`.
fn read_be_u16(msg: &[u8], pos: usize) -> Option<u16> {
    let bytes = msg.get(pos..pos + 2)?;
    Some(u16::from_be_bytes([bytes[0], bytes[1]]))
}

テストコード

#[cfg(test)]
mod tests {
    use super::*;

    /// End-to-end lookup against the public resolver at 8.8.8.8.
    ///
    /// Needs outbound UDP/53 access, so it is opt-in: `cargo test -- --ignored`.
    #[test]
    #[ignore = "requires network access to 8.8.8.8:53"]
    fn test_dns_resolve() {
        let addrs = resolve("google.com").unwrap();
        assert!(!addrs.is_empty());
    }

    /// Round-trips a payload through a live `TcpEchoServer` on loopback.
    #[test]
    fn test_tcp_echo() {
        use std::thread;

        // Bind on the test thread so the socket is already listening when the client
        // connects. The OS queues the connection until `run` accepts it, which removes
        // the race the previous fixed sleep tried to paper over.
        let server = TcpEchoServer::bind("127.0.0.1:19000").unwrap();
        thread::spawn(move || server.run().unwrap());

        // クライアントテスト (client round trip)
        let mut client = TcpEchoClient::connect("127.0.0.1:19000").unwrap();
        let response = client.send(b"Hello").unwrap();
        assert_eq!(response, b"Hello");
    }
}