`timescale 1ns / 1ps
`include "tb_tools.vh"

// Self-checking testbench for the combinational alu module.
//
// Each section selects one ALU operation via func (operation names come from
// ../rtl/alu_func.vh), drives in_a / in_b, and checks out with the assert
// macro from tb_tools.vh. Check labels are written as "a OP b" in the same
// operand order as the stimulus; all-ones is written as -1 and 0x80000000 as
// MIN_INT.
//
// NOTE(review): there is no delay between driving the inputs and each assert;
// this relies on the assert macro (defined in tb_tools.vh, not visible here)
// letting out settle before it is sampled -- TODO confirm.
module tb_alu ();
    `include "../rtl/alu_func.vh"

    reg  [31:0] in_a;   // operand A
    reg  [31:0] in_b;   // operand B (also the shift amount for SLL/SRL/SRA)
    reg  [3:0]  func;   // operation select, one of the names from alu_func.vh
    wire [31:0] out;    // result of the ALU under test

    alu alu (
        .in_a(in_a),
        .in_b(in_b),
        .func(func),
        .out(out)
    );

    initial begin
        // ADD: 32-bit sum; carries out of bit 31 are discarded (wraps).
        func = ADD;
        in_a = 32'd0;
        in_b = 32'd0;
        `assert("alu : 0 + 0", out, 0)
        in_a = 32'd1;
        `assert("alu : 1 + 0", out, 1)
        in_b = 32'd1;
        `assert("alu : 1 + 1", out, 2)
        in_a = 32'd0;
        `assert("alu : 0 + 1", out, 1)
        in_a = 32'd15;
        in_b = 32'd15;
        `assert("alu : 15 + 15", out, 30)
        in_a = 32'hFFFF_FFFF;
        in_b = 32'h0000_0000;
        `assert("alu : -1 + 0", out, 32'hFFFF_FFFF)
        in_a = 32'hFFFF_FFFF;
        in_b = 32'h0000_0001;
        `assert("alu : -1 + 1", out, 0)
        // Signed overflow: MIN_INT + -1 wraps to MAX_INT (0x7FFFFFFF).
        in_a = 32'h8000_0000;
        in_b = 32'hFFFF_FFFF;
        `assert("alu : MIN_INT + -1", out, 32'h7FFF_FFFF)

        // SUB: 32-bit difference in_a - in_b (wraps).
        func = SUB;
        in_a = 32'd0;
        in_b = 32'd0;
        `assert("alu : 0 - 0", out, 0)
        in_a = 32'd1;
        `assert("alu : 1 - 0", out, 1)
        in_b = 32'd1;
        `assert("alu : 1 - 1", out, 0)
        in_a = 32'd0;
        `assert("alu : 0 - 1", out, 32'hFFFF_FFFF)
        in_a = 32'd31;
        in_b = 32'd15;
        `assert("alu : 31 - 15", out, 16)
        in_a = 32'hFFFF_FFFF;
        in_b = 32'h0000_0000;
        `assert("alu : -1 - 0", out, 32'hFFFF_FFFF)
        in_a = 32'hFFFF_FFFF;
        in_b = 32'h0000_0001;
        `assert("alu : -1 - 1", out, 32'hFFFF_FFFE)
        in_a = 32'h8000_0000;
        in_b = 32'hFFFF_FFFF;
        `assert("alu : MIN_INT - -1", out, 32'h8000_0001)

        // SLL: logical left shift of in_a by in_b.
        func = SLL;
        in_a = 32'd1;
        in_b = 32'd1;
        `assert("alu : 1 << 1", out, 2)
        in_b = 32'd2;
        `assert("alu : 1 << 2", out, 4)
        in_a = 32'd3;
        `assert("alu : 3 << 2", out, 12)
        in_b = 32'd30;
        `assert("alu : 3 << 30", out, 32'hC000_0000)
        in_b = 32'd31;
        `assert("alu : 3 << 31", out, 32'h8000_0000)
        // NOTE(review): this expects a shift amount of 32 to clear the result.
        // If the ALU is meant to follow RISC-V RV32I (the operation names
        // suggest it, but alu_func.vh is not visible here), only in_b[4:0] is
        // the shift amount and the expected value would be 3 -- confirm
        // against the alu RTL.
        in_b = 32'd32;
        `assert("alu : 3 << 32", out, 32'h0000_0000)

        // SLT: 1 when in_a < in_b compared as signed integers, else 0.
        func = SLT;
        in_a = 32'd0;
        in_b = 32'd0;
        `assert("alu : 0 < 0", out, 0)
        in_b = 32'd2;
        `assert("alu : 0 < 2", out, 1)
        in_a = 32'd3;
        `assert("alu : 3 < 2", out, 0)
        in_b = 32'hFFFF_FFFF;
        in_a = 32'hFFFF_FFFF;
        `assert("alu : -1 < -1", out, 0)
        // -1 < 0 only holds under a signed compare.
        in_b = 32'd0;
        `assert("alu : -1 < 0", out, 1)
        in_a = 32'h8000_0000;
        in_b = 32'h8000_0001;
        `assert("alu : MIN_INT < MIN_INT + 1", out, 1)

        // XOR: bitwise exclusive or.
        func = XOR;
        in_a = 32'h0000_0000;
        in_b = 32'h0000_0000;
        `assert("alu : 0 ^ 0", out, 32'h0000_0000)
        in_a = 32'h0000_0001;
        `assert("alu : 1 ^ 0", out, 32'h0000_0001)
        in_a = 32'h0000_0000;
        in_b = 32'h0000_0001;
        `assert("alu : 0 ^ 1", out, 32'h0000_0001)
        in_a = 32'hFFFF_FFFF;
        in_b = 32'hFFFF_FFFF;
        `assert("alu : -1 ^ -1", out, 32'h0000_0000)
        in_a = 32'h0000_0000;
        in_b = 32'hFFFF_FFFF;
        `assert("alu : 0 ^ -1", out, 32'hFFFF_FFFF)
        // Mixed bit pattern (label keeps the binary spelling of the operands).
        in_a = 32'h0321_10C0;
        in_b = 32'hAF2E_EFFB;
        `assert("alu : 00000011001000010001000011000000 ^ 10101111001011101110111111111011", out, 32'hAC0F_FF3B)

        // SRL: logical right shift of in_a by in_b (zero fill).
        func = SRL;
        in_a = 32'd1;
        in_b = 32'd1;
        `assert("alu : 1 >> 1", out, 0)
        in_a = 32'd2;
        `assert("alu : 2 >> 1", out, 1)
        in_a = 32'd3;
        `assert("alu : 3 >> 1", out, 1)
        in_a = 32'd30;
        in_b = 32'd1;
        `assert("alu : 30 >> 1", out, 32'd15)
        in_a = 32'h8000_0000;
        in_b = 32'd31;
        `assert("alu : 1000...000 >> 31", out, 32'h0000_0001)
        in_a = 32'h80F0_01FF;
        in_b = 32'd31;
        `assert("alu : 1000..111 >> 31", out, 32'h0000_0001)

        // SRA: arithmetic right shift of in_a by in_b (sign-bit fill).
        func = SRA;
        in_a = 32'd1;
        in_b = 32'd1;
        `assert("alu : 1 >>> 1", out, 0)
        in_a = 32'd2;
        `assert("alu : 2 >>> 1", out, 1)
        in_a = 32'd3;
        `assert("alu : 3 >>> 1", out, 1)
        in_a = 32'd30;
        in_b = 32'd1;
        `assert("alu : 30 >>> 1", out, 32'd15)
        in_a = 32'h8000_0000;
        in_b = 32'd31;
        `assert("alu : 1000...000 >>> 31", out, 32'hFFFF_FFFF)
        in_a = 32'h80F0_01FF;
        in_b = 32'd31;
        `assert("alu : 1000..111 >>> 31", out, 32'hFFFF_FFFF)

        // OR: bitwise inclusive or.
        func = OR;
        in_a = 32'h0000_0000;
        in_b = 32'h0000_0000;
        `assert("alu : 0 | 0", out, 32'h0000_0000)
        in_a = 32'h0000_0001;
        `assert("alu : 1 | 0", out, 32'h0000_0001)
        in_a = 32'h0000_0000;
        in_b = 32'h0000_0001;
        `assert("alu : 0 | 1", out, 32'h0000_0001)
        in_a = 32'hFFFF_FFFF;
        in_b = 32'hFFFF_FFFF;
        `assert("alu : -1 | -1", out, 32'hFFFF_FFFF)
        in_a = 32'h0000_0000;
        in_b = 32'hFFFF_FFFF;
        `assert("alu : 0 | -1", out, 32'hFFFF_FFFF)
        in_a = 32'h0321_10C0;
        in_b = 32'hAF2E_EFFB;
        `assert("alu : 00000011001000010001000011000000 | 10101111001011101110111111111011", out, 32'hAF2F_FFFB)

        // AND: bitwise and.
        func = AND;
        in_a = 32'h0000_0000;
        in_b = 32'h0000_0000;
        `assert("alu : 0 & 0", out, 32'h0000_0000)
        in_a = 32'h0000_0001;
        `assert("alu : 1 & 0", out, 32'h0000_0000)
        in_a = 32'h0000_0000;
        in_b = 32'h0000_0001;
        `assert("alu : 0 & 1", out, 32'h0000_0000)
        in_a = 32'hFFFF_FFFF;
        in_b = 32'hFFFF_FFFF;
        `assert("alu : -1 & -1", out, 32'hFFFF_FFFF)
        in_a = 32'h0000_0000;
        in_b = 32'hFFFF_FFFF;
        `assert("alu : 0 & -1", out, 32'h0000_0000)
        in_a = 32'h0321_10C0;
        in_b = 32'hAF2E_EFFB;
        `assert("alu : 00000011001000010001000011000000 & 10101111001011101110111111111011", out, 32'h0320_00C0)

        `end_message
    end

endmodule : tb_alu