rust-network - 解答
実装コード
TCPエコーサーバー
// src/tcp.rs
use std::io::{Read, Write};
use std::net::{TcpListener, TcpStream};
use std::thread;
/// A blocking TCP echo server that serves each client on its own OS thread.
///
/// Every accepted connection is handed to `handle_client`, which writes back
/// whatever it reads until the peer closes the connection.
pub struct TcpEchoServer {
    listener: TcpListener,
}

impl TcpEchoServer {
    /// Binds a listening socket on `addr` (for example `"127.0.0.1:7000"`).
    ///
    /// # Errors
    /// Returns any error produced by `TcpListener::bind`.
    pub fn bind(addr: &str) -> std::io::Result<Self> {
        TcpListener::bind(addr).map(|listener| TcpEchoServer { listener })
    }

    /// Accepts connections forever, spawning one detached thread per client.
    ///
    /// A failed `accept` is logged to stderr and the loop continues. An error
    /// inside a client session is logged by that client's thread.
    ///
    /// # Errors
    /// Returns an error only if the listener's local address cannot be queried.
    pub fn run(&self) -> std::io::Result<()> {
        let local = self.listener.local_addr()?;
        println!("TCP Echo Server listening on {}", local);

        for incoming in self.listener.incoming() {
            let conn = match incoming {
                Ok(conn) => conn,
                Err(err) => {
                    eprintln!("Accept error: {}", err);
                    continue;
                }
            };
            thread::spawn(move || {
                if let Err(err) = handle_client(conn) {
                    eprintln!("Client error: {}", err);
                }
            });
        }

        Ok(())
    }
}
/// Echoes every byte received on `stream` back to the peer until the peer
/// closes its write half, which shows up as a zero-length read.
///
/// A read that fails with `ErrorKind::Interrupted` is retried rather than
/// treated as fatal. The `std::io::Read::read` contract documents that error
/// as non-fatal.
///
/// # Errors
/// Returns the first non-interrupt error from `peer_addr`, `read`,
/// `write_all`, or `flush`.
fn handle_client(mut stream: TcpStream) -> std::io::Result<()> {
    let peer = stream.peer_addr()?;
    println!("New connection from {}", peer);
    let mut buffer = [0u8; 4096];
    loop {
        let n = match stream.read(&mut buffer) {
            Ok(0) => {
                println!("Connection closed: {}", peer);
                return Ok(());
            }
            Ok(n) => n,
            // A signal interrupted the syscall before any data was read; retry.
            Err(e) if e.kind() == std::io::ErrorKind::Interrupted => continue,
            Err(e) => return Err(e),
        };
        stream.write_all(&buffer[..n])?;
        // `TcpStream` is unbuffered, so this flush is effectively a no-op. It is
        // kept so the function stays correct if the stream is ever buffered.
        stream.flush()?;
    }
}
// Client

/// A blocking client for an echo server such as `TcpEchoServer`.
pub struct TcpEchoClient {
    stream: TcpStream,
}

impl TcpEchoClient {
    /// Opens a TCP connection to `addr`.
    ///
    /// # Errors
    /// Propagates any error from `TcpStream::connect`.
    pub fn connect(addr: &str) -> std::io::Result<Self> {
        TcpStream::connect(addr).map(|stream| TcpEchoClient { stream })
    }

    /// Writes `data`, then blocks until exactly `data.len()` bytes come back.
    ///
    /// The reply length is assumed to equal the request length, which holds
    /// for an echo peer. The returned bytes are whatever the peer sent.
    ///
    /// # Errors
    /// Returns an error if writing fails, or `UnexpectedEof` if the connection
    /// closes before `data.len()` bytes have been read.
    pub fn send(&mut self, data: &[u8]) -> std::io::Result<Vec<u8>> {
        let conn = &mut self.stream;
        conn.write_all(data)?;
        conn.flush()?;

        let mut reply = vec![0u8; data.len()];
        conn.read_exact(&mut reply)?;
        Ok(reply)
    }
}
UDPチャット
// src/udp.rs
use std::collections::HashSet;
use std::net::{SocketAddr, UdpSocket};
use std::sync::{Arc, Mutex};
use std::thread;
/// A minimal peer-to-peer chat endpoint over UDP.
///
/// Messages go to every address in the shared peer set. Any address that
/// sends this endpoint a datagram is added to that set automatically.
pub struct UdpChat {
    socket: UdpSocket,
    // Known peers; shared with the background thread started by `run_receiver`.
    peers: Arc<Mutex<HashSet<SocketAddr>>>,
}

impl UdpChat {
    /// Binds a UDP socket on `0.0.0.0:port` and enables `SO_BROADCAST`.
    ///
    /// Pass `0` to let the OS pick a free port.
    ///
    /// # Errors
    /// Returns any error from binding the socket or enabling broadcast.
    pub fn new(port: u16) -> std::io::Result<Self> {
        let socket = UdpSocket::bind(format!("0.0.0.0:{}", port))?;
        socket.set_broadcast(true)?;
        Ok(UdpChat {
            socket,
            peers: Arc::new(Mutex::new(HashSet::new())),
        })
    }

    /// Adds `addr` to the set of peers that `send_message` delivers to.
    pub fn add_peer(&self, addr: SocketAddr) {
        self.peers.lock().unwrap().insert(addr);
    }

    /// Sends `msg` as one UTF-8 datagram to every known peer.
    ///
    /// # Errors
    /// Stops at, and returns, the first `send_to` failure. In that case the
    /// remaining peers in the (unordered) set are not sent to.
    pub fn send_message(&self, msg: &str) -> std::io::Result<()> {
        // Copy the peer list and release the lock *before* any socket I/O, so
        // a slow send never blocks `add_peer` or the receiver thread.
        let targets: Vec<SocketAddr> = self.peers.lock().unwrap().iter().copied().collect();
        for peer in targets {
            self.socket.send_to(msg.as_bytes(), peer)?;
        }
        Ok(())
    }

    /// Blocks until one datagram arrives, then registers its sender as a peer.
    ///
    /// Returns the payload, decoded lossily as UTF-8, together with the sender.
    ///
    /// # Errors
    /// Returns any error from `recv_from`, including a timeout if one is set.
    pub fn receive(&self) -> std::io::Result<(String, SocketAddr)> {
        let mut buffer = [0u8; 65535];
        let (n, from) = self.socket.recv_from(&mut buffer)?;
        let msg = UdpChat::accept_datagram(&self.peers, &buffer[..n], from);
        Ok((msg, from))
    }

    /// Spawns a detached thread that receives datagrams in an endless loop.
    ///
    /// Each sender is registered as a peer, and `callback(message, sender)` is
    /// invoked. Receive errors are logged to stderr and the loop keeps going.
    ///
    /// # Panics
    /// Panics if the socket cannot be duplicated for the background thread.
    pub fn run_receiver<F>(&self, callback: F)
    where
        F: Fn(String, SocketAddr) + Send + 'static,
    {
        let socket = self
            .socket
            .try_clone()
            .expect("UdpChat::run_receiver: failed to clone the UDP socket");
        let peers = Arc::clone(&self.peers);
        thread::spawn(move || {
            let mut buffer = [0u8; 65535];
            loop {
                match socket.recv_from(&mut buffer) {
                    Ok((n, from)) => {
                        let msg = UdpChat::accept_datagram(&peers, &buffer[..n], from);
                        callback(msg, from);
                    }
                    Err(e) => {
                        eprintln!("Receive error: {}", e);
                    }
                }
            }
        });
    }

    /// Shared receive path: records `from` as a peer and decodes `payload`.
    ///
    /// Invalid UTF-8 sequences are replaced with U+FFFD.
    fn accept_datagram(
        peers: &Mutex<HashSet<SocketAddr>>,
        payload: &[u8],
        from: SocketAddr,
    ) -> String {
        peers.lock().unwrap().insert(from);
        String::from_utf8_lossy(payload).into_owned()
    }
}
DNSリゾルバー
// src/dns.rs
use std::net::{IpAddr, Ipv4Addr, UdpSocket};
/// Public recursive resolver queried by `resolve` (Google Public DNS).
const DNS_SERVER: &str = "8.8.8.8:53";

/// Resolves `hostname` to its IPv4 addresses with a single UDP `A` query.
///
/// The query goes to `DNS_SERVER`. A single trailing dot (fully-qualified
/// form, e.g. `"example.com."`) is accepted and ignored.
///
/// The socket is `connect`ed to the server, so the OS discards datagrams from
/// any other source. A reply is only accepted if its transaction ID (first two
/// bytes) matches the query's.
///
/// # Errors
/// - `InvalidInput` if the name is empty, longer than 253 bytes, contains an
///   empty label, or has a label longer than 63 bytes.
/// - `InvalidData` if the reply's transaction ID does not match the query.
/// - Any socket error, including the read timing out after 5 seconds. The
///   timeout kind is platform-dependent: typically `WouldBlock` on Unix and
///   `TimedOut` on Windows.
pub fn resolve(hostname: &str) -> std::io::Result<Vec<IpAddr>> {
    let name = hostname.strip_suffix('.').unwrap_or(hostname);
    validate_hostname(name)?;

    let socket = UdpSocket::bind("0.0.0.0:0")?;
    socket.set_read_timeout(Some(std::time::Duration::from_secs(5)))?;
    socket.connect(DNS_SERVER)?;

    let query = build_dns_query(name);
    socket.send(&query)?;

    let mut buffer = [0u8; 512];
    let n = socket.recv(&mut buffer)?;
    let response = &buffer[..n];

    // Both packets start with the 16-bit transaction ID.
    if response.get(..2) != query.get(..2) {
        return Err(std::io::Error::new(
            std::io::ErrorKind::InvalidData,
            "DNS response transaction ID does not match the query",
        ));
    }
    parse_dns_response(response)
}

/// Rejects names that cannot be encoded as a DNS QNAME (RFC 1035 §2.3.4).
///
/// Expects `name` without a trailing dot.
fn validate_hostname(name: &str) -> std::io::Result<()> {
    let invalid = |why: &str| {
        std::io::Error::new(
            std::io::ErrorKind::InvalidInput,
            format!("invalid hostname {:?}: {}", name, why),
        )
    };
    if name.is_empty() {
        return Err(invalid("empty name"));
    }
    if name.len() > 253 {
        return Err(invalid("longer than 253 bytes"));
    }
    if name.split('.').any(|label| label.is_empty()) {
        return Err(invalid("contains an empty label"));
    }
    if name.split('.').any(|label| label.len() > 63) {
        return Err(invalid("has a label longer than 63 bytes"));
    }
    Ok(())
}
/// Builds a DNS query packet (RFC 1035 §4.1) asking for the `A` records of `hostname`.
///
/// The header uses the fixed transaction ID `0x0001`, sets RD (recursion
/// desired), and carries exactly one question of type A, class IN.
///
/// Empty labels are skipped. A fully-qualified name with a trailing dot
/// (`"example.com."`) therefore encodes the same as `"example.com"`, and
/// `""` / `"."` encode the root name. Previously an empty label was written as
/// a zero-length label, which terminated the name early and misaligned the
/// QTYPE/QCLASS fields.
///
/// NOTE: labels must be at most 63 bytes. Longer labels cannot be represented
/// in the wire format, and this function does not validate them.
fn build_dns_query(hostname: &str) -> Vec<u8> {
    let mut query = Vec::with_capacity(12 + hostname.len() + 2 + 4);
    // Transaction ID
    query.extend_from_slice(&[0x00, 0x01]);
    // Flags: standard query, recursion desired
    query.extend_from_slice(&[0x01, 0x00]);
    // Questions: 1
    query.extend_from_slice(&[0x00, 0x01]);
    // Answer RRs, Authority RRs, Additional RRs: 0
    query.extend_from_slice(&[0x00, 0x00, 0x00, 0x00, 0x00, 0x00]);
    // Query name as length-prefixed labels
    for label in hostname.split('.').filter(|label| !label.is_empty()) {
        query.push(label.len() as u8);
        query.extend_from_slice(label.as_bytes());
    }
    query.push(0x00); // Root label terminates the name
    // Type: A (1)
    query.extend_from_slice(&[0x00, 0x01]);
    // Class: IN (1)
    query.extend_from_slice(&[0x00, 0x01]);
    query
}
/// Extracts the IPv4 addresses carried by the `A` records of a DNS response.
///
/// The packet comes from the network, so every read is bounds-checked. If the
/// packet is truncated or malformed, parsing stops and the addresses collected
/// so far are returned; this function never panics on bad input.
///
/// The question section is skipped according to QDCOUNT. Only answer records
/// with TYPE = A (1) and a 4-byte RDATA are reported. Other records, such as
/// CNAME, are skipped.
///
/// # Errors
/// Currently never returns `Err`; the `Result` is kept for API stability.
fn parse_dns_response(response: &[u8]) -> std::io::Result<Vec<IpAddr>> {
    let mut addrs = Vec::new();
    // `None` only means the packet ended early or was malformed; keep what we have.
    let _ = collect_a_records(response, &mut addrs);
    Ok(addrs)
}

/// Walks the header, question, and answer sections, pushing each A record into `addrs`.
///
/// Returns `None` as soon as a read would fall outside `packet`.
fn collect_a_records(packet: &[u8], addrs: &mut Vec<IpAddr>) -> Option<()> {
    let header = packet.get(..12)?;
    let question_count = u16::from_be_bytes([header[4], header[5]]);
    let answer_count = u16::from_be_bytes([header[6], header[7]]);

    let mut pos = 12;
    for _ in 0..question_count {
        pos = skip_name(packet, pos)? + 4; // QTYPE + QCLASS
    }

    for _ in 0..answer_count {
        pos = skip_name(packet, pos)?;
        // Fixed part: TYPE(2) CLASS(2) TTL(4) RDLENGTH(2)
        let fixed = packet.get(pos..pos + 10)?;
        let rtype = u16::from_be_bytes([fixed[0], fixed[1]]);
        let rdlength = usize::from(u16::from_be_bytes([fixed[8], fixed[9]]));
        pos += 10;

        let rdata = packet.get(pos..pos + rdlength)?;
        if rtype == 1 && rdata.len() == 4 {
            addrs.push(IpAddr::V4(Ipv4Addr::new(
                rdata[0], rdata[1], rdata[2], rdata[3],
            )));
        }
        pos += rdlength;
    }
    Some(())
}

/// Returns the offset just past the domain name that starts at `pos`.
///
/// Handles plain label sequences, a bare compression pointer, and labels
/// followed by a pointer. Pointers are not followed; they only end the name.
/// Returns `None` on truncation or on the reserved 0x40/0x80 label types.
fn skip_name(packet: &[u8], mut pos: usize) -> Option<usize> {
    loop {
        let len = *packet.get(pos)?;
        match len & 0xC0 {
            // Two-byte compression pointer always terminates the name.
            0xC0 => {
                packet.get(pos + 1)?;
                return Some(pos + 2);
            }
            // Ordinary label; the zero-length root label ends the name.
            0x00 => {
                pos += 1 + usize::from(len);
                if len == 0 {
                    return Some(pos);
                }
            }
            _ => return None,
        }
    }
}
テストコード
#[cfg(test)]
mod tests {
    use super::*;

    // Talks to a real public resolver, so it is opt-in:
    // run with `cargo test -- --ignored`.
    #[test]
    #[ignore = "requires outbound UDP access to 8.8.8.8"]
    fn test_dns_resolve() {
        let addrs = resolve("google.com").unwrap();
        assert!(!addrs.is_empty());
    }

    #[test]
    fn test_tcp_echo() {
        use std::thread;

        // Bind to an OS-assigned port so concurrent test runs cannot collide.
        // The socket is listening before the server thread starts, so the
        // kernel queues the client's connection and no sleep is needed.
        let server = TcpEchoServer::bind("127.0.0.1:0").unwrap();
        let addr = server.listener.local_addr().unwrap();
        thread::spawn(move || server.run().unwrap());

        // Client side of the test.
        let mut client = TcpEchoClient::connect(&addr.to_string()).unwrap();
        let response = client.send(b"Hello").unwrap();
        assert_eq!(response, b"Hello");
    }
}