解答6: 関数の実践

概要

この解答では、Zigの関数機能を総合的に実装します。基本的な数学関数から始まり、comptime関数によるコンパイル時計算、高階関数を使った関数型プログラミング、そして再帰関数を使ったアルゴリズムまで幅広くカバーします。

Part 1: 基本的な数学関数

完全な実装

const std = @import("std");

// 最大公約数(ユークリッドの互除法)
fn gcd(a: u64, b: u64) u64 {
    if (b == 0) return a;
    return gcd(b, a % b);
}

// 最小公倍数
fn lcm(a: u64, b: u64) u64 {
    if (a == 0 or b == 0) return 0;
    return (a * b) / gcd(a, b);
}

// べき乗計算(繰り返し二乗法)
fn power(base: i64, exp: u32) i64 {
    if (exp == 0) return 1;
    if (exp == 1) return base;

    // 繰り返し二乗法で効率的に計算
    var result: i64 = 1;
    var current_base = base;
    var current_exp = exp;

    while (current_exp > 0) {
        if (current_exp % 2 == 1) {
            result *= current_base;
        }
        current_base *= current_base;
        current_exp /= 2;
    }

    return result;
}

// 階乗
fn factorial(n: u64) u64 {
    if (n <= 1) return 1;
    return n * factorial(n - 1);
}

// 組み合わせ C(n, k) = n! / (k! * (n-k)!)
fn combination(n: u64, k: u64) u64 {
    if (k > n) return 0;
    if (k == 0 or k == n) return 1;

    // オーバーフロー防止のため最適化
    const k_opt = if (k > n - k) n - k else k;

    var result: u64 = 1;
    var i: u64 = 0;
    while (i < k_opt) : (i += 1) {
        result = result * (n - i) / (i + 1);
    }

    return result;
}

pub fn main() void {
    std.debug.print("=== GCD/LCM ===\n", .{});
    std.debug.print("gcd(48, 18) = {}\n", .{gcd(48, 18)});
    std.debug.print("lcm(12, 18) = {}\n", .{lcm(12, 18)});

    std.debug.print("\n=== Power ===\n", .{});
    std.debug.print("2^10 = {}\n", .{power(2, 10)});
    std.debug.print("3^5 = {}\n", .{power(3, 5)});

    std.debug.print("\n=== Factorial ===\n", .{});
    std.debug.print("5! = {}\n", .{factorial(5)});
    std.debug.print("10! = {}\n", .{factorial(10)});

    std.debug.print("\n=== Combination ===\n", .{});
    std.debug.print("C(5, 2) = {}\n", .{combination(5, 2)});
    std.debug.print("C(10, 3) = {}\n", .{combination(10, 3)});
}

実装のポイント

  • ユークリッドの互除法: 再帰的に実装し、bが0になったらaを返します。
  • 繰り返し二乗法: O(log n)の時間計算量で効率的にべき乗を計算します。
  • 組み合わせの最適化: オーバーフローを防ぐため、乗算と除算を交互に行います。
  • Part 2: comptime関数

    完全な実装

    const std = @import("std");
    
    // コンパイル時にフィボナッチ数を計算
    fn fibonacci(comptime n: u32) u64 {
        if (n == 0) return 0;
        if (n == 1) return 1;
        return fibonacci(n - 1) + fibonacci(n - 2);
    }
    
    // コンパイル時に素数判定
    fn isPrime(comptime n: u64) bool {
        if (n < 2) return false;
        if (n == 2) return true;
        if (n % 2 == 0) return false;
    
        comptime var i: u64 = 3;
        inline while (i * i <= n) : (i += 2) {
            if (n % i == 0) return false;
        }
    
        return true;
    }
    
    // ジェネリック関数: 配列の合計
    fn sum(comptime T: type, array: []const T) T {
        var total: T = 0;
        for (array) |item| {
            total += item;
        }
        return total;
    }
    
    // ジェネリック関数: 配列の平均
    fn average(comptime T: type, array: []const T) ?T {
        if (array.len == 0) return null;
    
        const total = sum(T, array);
    
        // 型に応じて除算方法を変える
        return switch (@typeInfo(T)) {
            .Int => @divTrunc(total, @as(T, @intCast(array.len))),
            .Float => total / @as(T, @floatFromInt(array.len)),
            else => @compileError("Unsupported type for average"),
        };
    }
    
    // ジェネリック関数: 配列の最大値と最小値を返す
    fn minMax(comptime T: type, array: []const T) ?struct { min: T, max: T } {
        if (array.len == 0) return null;
    
        var min_val = array[0];
        var max_val = array[0];
    
        for (array[1..]) |item| {
            if (item < min_val) min_val = item;
            if (item > max_val) max_val = item;
        }
    
        return .{ .min = min_val, .max = max_val };
    }
    
    pub fn main() void {
        std.debug.print("=== Compile-time Fibonacci ===\n", .{});
        const fib_10 = fibonacci(10);
        const fib_20 = fibonacci(20);
        std.debug.print("fib(10) = {}\n", .{fib_10});
        std.debug.print("fib(20) = {}\n", .{fib_20});
    
        std.debug.print("\n=== Compile-time Prime Check ===\n", .{});
        std.debug.print("17 is prime: {}\n", .{isPrime(17)});
        std.debug.print("18 is prime: {}\n", .{isPrime(18)});
    
        std.debug.print("\n=== Generic Sum ===\n", .{});
        const int_arr = [_]i32{ 1, 2, 3, 4, 5 };
        const float_arr = [_]f64{ 1.5, 2.5, 3.5 };
        std.debug.print("Sum of integers: {}\n", .{sum(i32, &int_arr)});
        std.debug.print("Sum of floats: {d:.1}\n", .{sum(f64, &float_arr)});
    
        std.debug.print("\n=== Generic Average ===\n", .{});
        if (average(i32, &int_arr)) |avg| {
            std.debug.print("Average: {}\n", .{avg});
        }
    
        std.debug.print("\n=== Generic Min/Max ===\n", .{});
        if (minMax(i32, &int_arr)) |result| {
            std.debug.print("Min: {}, Max: {}\n", .{result.min, result.max});
        }
    }
    

    実装のポイント

  • comptime再帰: フィボナッチ数はコンパイル時に完全に計算されます。
  • inline while: 素数判定ループはコンパイル時に展開されます。
  • 型情報を使った分岐: @typeInfoを使って整数と浮動小数点数で異なる処理を行います。
  • Part 3: 高階関数

    完全な実装

    const std = @import("std");
    
    // map関数(配列の各要素に関数を適用)
    fn map(
        comptime T: type,
        comptime R: type,
        allocator: std.mem.Allocator,
        array: []const T,
        func: *const fn(T) R,
    ) ![]R {
        const result = try allocator.alloc(R, array.len);
    
        for (array, 0..) |item, i| {
            result[i] = func(item);
        }
    
        return result;
    }
    
    // filter関数(条件を満たす要素を抽出)
    fn filter(
        comptime T: type,
        allocator: std.mem.Allocator,
        array: []const T,
        predicate: *const fn(T) bool,
    ) ![]T {
        var result = std.ArrayList(T).init(allocator);
        defer result.deinit();
    
        for (array) |item| {
            if (predicate(item)) {
                try result.append(item);
            }
        }
    
        return result.toOwnedSlice();
    }
    
    // reduce関数(配列を単一の値に集約)
    fn reduce(
        comptime T: type,
        comptime R: type,
        array: []const T,
        initial: R,
        func: *const fn(R, T) R,
    ) R {
        var accumulator = initial;
    
        for (array) |item| {
            accumulator = func(accumulator, item);
        }
    
        return accumulator;
    }
    
    // forEach関数(各要素に対して処理を実行)
    fn forEach(
        comptime T: type,
        array: []const T,
        callback: *const fn(T) void,
    ) void {
        for (array) |item| {
            callback(item);
        }
    }
    
    // テスト用の関数
    fn double(x: i32) i32 {
        return x * 2;
    }
    
    fn isEven(x: i32) bool {
        return @mod(x, 2) == 0;
    }
    
    fn add(acc: i32, x: i32) i32 {
        return acc + x;
    }
    
    fn printValue(x: i32) void {
        std.debug.print("{} ", .{x});
    }
    
    pub fn main() !void {
        const allocator = std.heap.page_allocator;
        const numbers = [_]i32{ 1, 2, 3, 4, 5 };
    
        std.debug.print("=== Map ===\n", .{});
        const doubled = try map(i32, i32, allocator, &numbers, &double);
        defer allocator.free(doubled);
        std.debug.print("Doubled: ", .{});
        for (doubled) |n| std.debug.print("{} ", .{n});
        std.debug.print("\n", .{});
    
        std.debug.print("\n=== Filter ===\n", .{});
        const evens = try filter(i32, allocator, &numbers, &isEven);
        defer allocator.free(evens);
        std.debug.print("Evens: ", .{});
        for (evens) |n| std.debug.print("{} ", .{n});
        std.debug.print("\n", .{});
    
        std.debug.print("\n=== Reduce ===\n", .{});
        const sum_result = reduce(i32, i32, &numbers, 0, &add);
        std.debug.print("Sum: {}\n", .{sum_result});
    
        std.debug.print("\n=== ForEach ===\n", .{});
        std.debug.print("Numbers: ", .{});
        forEach(i32, &numbers, &printValue);
        std.debug.print("\n", .{});
    }
    

    実装のポイント

  • メモリ管理: mapとfilterでは動的メモリを確保し、呼び出し側でfreeする責任があります。
  • ArrayList: filterでは結果のサイズが事前にわからないため、ArrayListを使用します。
  • 関数ポインタ: *const fn(T) Rの形式で関数ポインタを受け取ります。
  • Part 4: 再帰関数

    完全な実装

    const std = @import("std");
    
    // ハノイの塔
    fn hanoi(n: u32, from: []const u8, to: []const u8, aux: []const u8) void {
        if (n == 1) {
            std.debug.print("Move disk 1 from {s} to {s}\n", .{from, to});
            return;
        }
    
        // n-1枚を from -> aux
        hanoi(n - 1, from, aux, to);
    
        // 1枚を from -> to
        std.debug.print("Move disk {} from {s} to {s}\n", .{n, from, to});
    
        // n-1枚を aux -> to
        hanoi(n - 1, aux, to, from);
    }
    
    // 二分探索(再帰版)
    fn binarySearch(comptime T: type, array: []const T, target: T, left: usize, right: usize) ?usize {
        if (left > right) return null;
    
        const mid = left + (right - left) / 2;
    
        if (array[mid] == target) {
            return mid;
        } else if (array[mid] > target) {
            if (mid == 0) return null;
            return binarySearch(T, array, target, left, mid - 1);
        } else {
            return binarySearch(T, array, target, mid + 1, right);
        }
    }
    
    // クイックソート
    fn quickSort(comptime T: type, array: []T) void {
        if (array.len <= 1) return;
    
        const pivot_index = partition(T, array);
    
        if (pivot_index > 0) {
            quickSort(T, array[0..pivot_index]);
        }
        if (pivot_index + 1 < array.len) {
            quickSort(T, array[pivot_index + 1..]);
        }
    }
    
    fn partition(comptime T: type, array: []T) usize {
        const pivot = array[array.len - 1];
        var i: usize = 0;
    
        for (array[0..array.len - 1], 0..) |_, j| {
            if (array[j] <= pivot) {
                const temp = array[i];
                array[i] = array[j];
                array[j] = temp;
                i += 1;
            }
        }
    
        const temp = array[i];
        array[i] = array[array.len - 1];
        array[array.len - 1] = temp;
    
        return i;
    }
    
    // マージソート
    fn mergeSort(comptime T: type, allocator: std.mem.Allocator, array: []T) !void {
        if (array.len <= 1) return;
    
        const mid = array.len / 2;
    
        try mergeSort(T, allocator, array[0..mid]);
        try mergeSort(T, allocator, array[mid..]);
    
        try merge(T, allocator, array, mid);
    }
    
    fn merge(comptime T: type, allocator: std.mem.Allocator, array: []T, mid: usize) !void {
        const left = try allocator.dupe(T, array[0..mid]);
        defer allocator.free(left);
    
        const right = try allocator.dupe(T, array[mid..]);
        defer allocator.free(right);
    
        var i: usize = 0;
        var j: usize = 0;
        var k: usize = 0;
    
        while (i < left.len and j < right.len) {
            if (left[i] <= right[j]) {
                array[k] = left[i];
                i += 1;
            } else {
                array[k] = right[j];
                j += 1;
            }
            k += 1;
        }
    
        while (i < left.len) {
            array[k] = left[i];
            i += 1;
            k += 1;
        }
    
        while (j < right.len) {
            array[k] = right[j];
            j += 1;
            k += 1;
        }
    }
    
    // パスカルの三角形のn行目を生成
    fn pascalRow(allocator: std.mem.Allocator, n: usize) ![]u64 {
        const row = try allocator.alloc(u64, n + 1);
    
        row[0] = 1;
        if (n == 0) return row;
    
        for (1..n + 1) |i| {
            row[i] = row[i - 1] * (n - i + 1) / i;
        }
    
        return row;
    }
    
    pub fn main() !void {
        std.debug.print("=== Hanoi Tower ===\n", .{});
        hanoi(3, "A", "C", "B");
    
        std.debug.print("\n=== Binary Search ===\n", .{});
        const sorted = [_]i32{ 1, 3, 5, 7, 9, 11, 13, 15 };
        if (binarySearch(i32, &sorted, 7, 0, sorted.len - 1)) |index| {
            std.debug.print("Found 7 at index {}\n", .{index});
        }
    
        std.debug.print("\n=== Quick Sort ===\n", .{});
        var arr1 = [_]i32{ 5, 2, 8, 1, 9, 3 };
        quickSort(i32, &arr1);
        std.debug.print("Sorted: ", .{});
        for (arr1) |n| std.debug.print("{} ", .{n});
        std.debug.print("\n", .{});
    
        std.debug.print("\n=== Merge Sort ===\n", .{});
        const allocator = std.heap.page_allocator;
        var arr2 = [_]i32{ 5, 2, 8, 1, 9, 3 };
        try mergeSort(i32, allocator, &arr2);
        std.debug.print("Sorted: ", .{});
        for (arr2) |n| std.debug.print("{} ", .{n});
        std.debug.print("\n", .{});
    
        std.debug.print("\n=== Pascal's Triangle ===\n", .{});
        for (0..6) |i| {
            const row = try pascalRow(allocator, i);
            defer allocator.free(row);
            for (row) |n| std.debug.print("{} ", .{n});
            std.debug.print("\n", .{});
        }
    }
    

    実装のポイント

  • ハノイの塔: 古典的な再帰問題で、3つのステップに分割します。
  • クイックソート: 分割統治法を使い、pivot要素で配列を分割します。
  • マージソート: 安定ソートで、一時配列を使ってマージします。
  • ボーナス課題の解答

    Bonus 1: メモ化フィボナッチ

    const std = @import("std");
    
    const Memoizer = struct {
        cache: std.AutoHashMap(u64, u64),
        allocator: std.mem.Allocator,
    
        pub fn init(allocator: std.mem.Allocator) Memoizer {
            return Memoizer{
                .cache = std.AutoHashMap(u64, u64).init(allocator),
                .allocator = allocator,
            };
        }
    
        pub fn deinit(self: *Memoizer) void {
            self.cache.deinit();
        }
    
        pub fn fibonacci(self: *Memoizer, n: u64) !u64 {
            // ベースケース
            if (n <= 1) return n;
    
            // キャッシュをチェック
            if (self.cache.get(n)) |cached| {
                return cached;
            }
    
            // 計算してキャッシュに保存
            const result = try self.fibonacci(n - 1) + try self.fibonacci(n - 2);
            try self.cache.put(n, result);
    
            return result;
        }
    };
    
    pub fn main() !void {
        const allocator = std.heap.page_allocator;
        var memo = Memoizer.init(allocator);
        defer memo.deinit();
    
        std.debug.print("Memoized Fibonacci:\n", .{});
        for (0..45) |i| {
            const result = try memo.fibonacci(@intCast(i));
            std.debug.print("fib({}) = {}\n", .{i, result});
        }
    }
    

    Bonus 2: 関数合成

    const std = @import("std");
    
    // 2つの関数を合成
    fn compose(
        comptime A: type,
        comptime B: type,
        comptime C: type,
        f: *const fn(B) C,
        g: *const fn(A) B,
    ) fn(A) C {
        const ComposedFn = struct {
            fn call(x: A) C {
                return f(g(x));
            }
        };
        return ComposedFn.call;
    }
    
    // テスト用関数
    fn addOne(x: i32) i32 {
        return x + 1;
    }
    
    fn double(x: i32) i32 {
        return x * 2;
    }
    
    fn square(x: i32) i32 {
        return x * x;
    }
    
    pub fn main() void {
        // 関数合成のテスト
        const addOneThenDouble = compose(i32, i32, i32, &double, &addOne);
        const doubleThenSquare = compose(i32, i32, i32, &square, &double);
    
        std.debug.print("(addOne then double)(5) = {}\n", .{addOneThenDouble(5)});
        std.debug.print("(double then square)(5) = {}\n", .{doubleThenSquare(5)});
    }
    

    よくある間違い

    1. メモリリークの忘れ

    // 悪い例: メモリを解放していない
    const result = try map(i32, i32, allocator, &numbers, &double);
    // 使用後に解放していない!
    
    // 良い例: deferで確実に解放
    const result = try map(i32, i32, allocator, &numbers, &double);
    defer allocator.free(result);
    

    2. comptime関数の誤用

    // 悪い例: 実行時の値をcomptimeパラメータに渡そうとする
    var n: u32 = 10;
    const fib = fibonacci(n); // コンパイルエラー!
    
    // 良い例: コンパイル時定数を使用
    const n = 10;
    const fib = fibonacci(n); // OK
    

    3. 関数ポインタの型不一致

    // 悪い例: 関数シグネチャが一致していない
    fn wrongDouble(x: i64) i64 { return x * 2; }
    const result = try map(i32, i32, allocator, &numbers, &wrongDouble); // 型エラー!
    
    // 良い例: 正しい型シグネチャを使用
    fn correctDouble(x: i32) i32 { return x * 2; }
    const result = try map(i32, i32, allocator, &numbers, &correctDouble); // OK
    

    発展課題

    1. 部分適用(カリー化)の実装

    カリー化を使って、複数引数の関数を1引数の関数に変換してみましょう。

    2. 遅延評価ストリーム

    無限リストを表現する遅延評価ストリームを実装してみましょう。

    3. モナド風のチェーン操作

    Option型やResult型を使って、エラー処理を含むチェーン操作を実装してみましょう。

    4. 高速なフィボナッチ(行列累乗法)

    O(log n)の時間計算量でフィボナッチ数を計算する行列累乗法を実装してみましょう。

    まとめ

    この解答では、以下の重要な概念を学びました:

  • 数学関数の実装: アルゴリズムを理解し、効率的に実装する
  • comptime関数: コンパイル時計算で実行時のオーバーヘッドをゼロにする
  • 高階関数: 関数を引数や戻り値として扱い、抽象度を高める
  • 再帰関数: 分割統治法で複雑な問題を解決する
  • メモ化: 動的計画法で計算を最適化する

これらの技術を組み合わせることで、効率的で保守性の高いコードを書くことができます。