1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
use tensor::Tensor;
use std::cmp::min;
use num::traits::Zero;
mod solve;
mod svd;
/// Builds a diagonal matrix from a vector, or extracts the main diagonal of a matrix.
///
/// * 1-D input of `a.size()` elements: returns a square tensor of shape
///   `[a.size(), a.size()]`. It is created with `Tensor::zeros`, so the off-diagonal
///   entries are presumably zero. Element `(k, k)` is set to `a[(k,)]`.
/// * 2-D input: returns a 1-D tensor of length `min(a.dim(0), a.dim(1))` whose
///   `k`-th element is `a[(k, k)]`. Non-square matrices are handled this way.
///
/// # Panics
///
/// Panics with "Can only run diag for vectors and matrices" when `a.ndim()` is
/// neither 1 nor 2.
pub fn diag<T: Copy + Zero>(a: &Tensor<T>) -> Tensor<T> {
    match a.ndim() {
        1 => {
            // Vector -> square matrix with the vector on its diagonal.
            let n = a.size();
            let mut square = Tensor::zeros(&[n, n]);
            for k in 0..n {
                square[(k, k)] = a[(k,)];
            }
            square
        }
        2 => {
            // Matrix -> vector of its main diagonal; the shorter side bounds the length.
            let len = min(a.dim(0), a.dim(1));
            let mut diagonal = Tensor::zeros(&[len]);
            for k in 0..len {
                diagonal[(k,)] = a[(k, k)];
            }
            diagonal
        }
        _ => panic!("Can only run diag for vectors and matrices"),
    }
}