1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
use std::ops::{Add, Mul, BitAnd, BitOr, BitXor};
use tensor::{Tensor, Full, Index};
use traits::{NumericTrait, TensorTrait};

impl<T: NumericTrait> Tensor<T> {
    /// Returns the largest element of the tensor.
    ///
    /// Elements are compared with `>`, so on ties the earliest element wins.
    /// NOTE(review): for float element types, a NaN in the first position is
    /// returned as the result (nothing compares greater than NaN), while a NaN
    /// anywhere later is skipped.
    ///
    /// # Panics
    ///
    /// Panics if the tensor is empty.
    pub fn max(&self) -> T {
        assert!(self.size() > 0, "Can't take max of empty tensor");
        let mut values = self.iter();
        // Seed with the first element instead of a placeholder value.
        let first = values
            .next()
            .expect("non-empty tensor yields at least one element");
        values.fold(first, |best, v| if v > best { v } else { best })
    }

    /// Returns the smallest element of the tensor.
    ///
    /// Elements are compared with `<`, so on ties the earliest element wins.
    /// NOTE(review): for float element types, a NaN in the first position is
    /// returned as the result (nothing compares less than NaN), while a NaN
    /// anywhere later is skipped.
    ///
    /// # Panics
    ///
    /// Panics if the tensor is empty.
    pub fn min(&self) -> T {
        assert!(self.size() > 0, "Can't take min of empty tensor");
        let mut values = self.iter();
        // Seed with the first element instead of a placeholder value.
        let first = values
            .next()
            .expect("non-empty tensor yields at least one element");
        values.fold(first, |best, v| if v < best { v } else { best })
    }

    /// Returns the sum of all elements, accumulated left to right.
    ///
    /// An empty tensor sums to `T::zero()`.
    pub fn sum(&self) -> T {
        self.iter().fold(T::zero(), |acc, v| acc + v)
    }

    /// Returns the arithmetic mean of all elements (sum divided by count).
    ///
    /// For integer `T` the final division truncates.
    ///
    /// # Panics
    ///
    /// Panics if the tensor is empty. This matches `max`/`min`. Without the
    /// check, the computation would be 0/0: NaN for floats, a divide-by-zero
    /// panic for integers.
    pub fn mean(&self) -> T {
        assert!(self.size() > 0, "Can't take mean of empty tensor");
        // The element count is built up in `T` itself by repeatedly adding
        // `T::one()`, because only `zero`/`one` are used here to create `T` values.
        // NOTE(review): for narrow integer `T` (e.g. `u8`) this counter overflows
        // after `T::MAX` elements. For `f32` it stops increasing past 2^24.
        // Converting `usize -> T` would avoid this, if `NumericTrait` offers one.
        let (total, count) = self
            .iter()
            .fold((T::zero(), T::zero()), |(s, n), v| (s + v, n + T::one()));
        total / count
    }
}

// Generates an axis-wise reduction method on `Tensor<T>`.
//
// - `$trait_name`: a binary operator trait (e.g. `Add`) whose output is `T`.
// - `$func_name`: that trait's method (e.g. `add`), applied as
//   `&Tensor<T> op &Tensor<T>`.
// - `$new_func_name`: the name of the generated reduction method.
macro_rules! add_impl {
    ($trait_name:ident, $func_name:ident, $new_func_name:ident) => (
        impl<T: TensorTrait + $trait_name<Output=T>> Tensor<T> {
            /// Folds every slice along `axis` into one tensor, left to right,
            /// using the operator this method was generated for.
            ///
            /// Each slice is selected with `Full` on the leading axes and
            /// `Index(k)` on `axis`. NOTE(review): this presumably drops `axis`
            /// from the result's shape; confirm against `Tensor::index`.
            ///
            /// # Panics
            ///
            /// Panics if `axis >= self.ndim()`.
            pub fn $new_func_name(&self, axis: usize) -> Tensor<T> {
                assert!(axis < self.ndim(), "Reduced axis must exist");

                // `Full` for every axis before `axis`; the final slot picks one index.
                let mut selector = vec![Full; axis + 1];
                selector[axis] = Index(0);
                let mut acc = self.index(&selector[..]);

                for k in 1..self.dim(axis) {
                    selector[axis] = Index(k as isize);
                    let slice = self.index(&selector[..]);
                    acc = (&acc).$func_name(&slice);
                }

                acc
            }
        }
    )
}

// Arithmetic reductions along a single axis.
add_impl!(Mul, mul, prod_axis);
add_impl!(Add, add, sum_axis);

// Bitwise reductions along a single axis; each exists only for element types
// implementing the corresponding `std::ops` bit operator.
add_impl!(BitXor, bitxor, bitxor_axis);
add_impl!(BitOr, bitor, bitor_axis);
add_impl!(BitAnd, bitand, bitand_axis);