1
1
use crate :: { FromCv , TryFromCv } ;
2
2
use anyhow:: { Error , Result } ;
3
- use ndarray as nd;
4
3
5
4
use to_ndarray_shape:: * ;
6
5
mod to_ndarray_shape {
7
6
use super :: * ;
8
7
9
8
pub trait ToNdArrayShape < D >
10
9
where
11
- Self :: Output : Sized + Into < nd :: StrideShape < D > > ,
10
+ Self :: Output : Sized + Into < ndarray :: StrideShape < D > > ,
12
11
{
13
12
type Output ;
14
13
type Error ;
15
14
16
15
fn to_ndarray_shape ( & self ) -> Result < Self :: Output , Self :: Error > ;
17
16
}
18
17
19
- impl ToNdArrayShape < nd :: IxDyn > for Vec < i64 > {
18
+ impl ToNdArrayShape < ndarray :: IxDyn > for Vec < i64 > {
20
19
type Output = Vec < usize > ;
21
20
type Error = Error ;
22
21
@@ -26,7 +25,7 @@ mod to_ndarray_shape {
26
25
}
27
26
}
28
27
29
- impl ToNdArrayShape < nd :: Ix0 > for Vec < i64 > {
28
+ impl ToNdArrayShape < ndarray :: Ix0 > for Vec < i64 > {
30
29
type Output = [ usize ; 0 ] ;
31
30
type Error = Error ;
32
31
@@ -40,7 +39,7 @@ mod to_ndarray_shape {
40
39
}
41
40
}
42
41
43
- impl ToNdArrayShape < nd :: Ix1 > for Vec < i64 > {
42
+ impl ToNdArrayShape < ndarray :: Ix1 > for Vec < i64 > {
44
43
type Output = [ usize ; 1 ] ;
45
44
type Error = Error ;
46
45
@@ -53,7 +52,7 @@ mod to_ndarray_shape {
53
52
}
54
53
}
55
54
56
- impl ToNdArrayShape < nd :: Ix2 > for Vec < i64 > {
55
+ impl ToNdArrayShape < ndarray :: Ix2 > for Vec < i64 > {
57
56
type Output = [ usize ; 2 ] ;
58
57
type Error = Error ;
59
58
@@ -66,7 +65,7 @@ mod to_ndarray_shape {
66
65
}
67
66
}
68
67
69
- impl ToNdArrayShape < nd :: Ix3 > for Vec < i64 > {
68
+ impl ToNdArrayShape < ndarray :: Ix3 > for Vec < i64 > {
70
69
type Output = [ usize ; 3 ] ;
71
70
type Error = Error ;
72
71
@@ -79,7 +78,7 @@ mod to_ndarray_shape {
79
78
}
80
79
}
81
80
82
- impl ToNdArrayShape < nd :: Ix4 > for Vec < i64 > {
81
+ impl ToNdArrayShape < ndarray :: Ix4 > for Vec < i64 > {
83
82
type Output = [ usize ; 4 ] ;
84
83
type Error = Error ;
85
84
@@ -92,7 +91,7 @@ mod to_ndarray_shape {
92
91
}
93
92
}
94
93
95
- impl ToNdArrayShape < nd :: Ix5 > for Vec < i64 > {
94
+ impl ToNdArrayShape < ndarray :: Ix5 > for Vec < i64 > {
96
95
type Output = [ usize ; 5 ] ;
97
96
type Error = Error ;
98
97
@@ -111,7 +110,7 @@ mod to_ndarray_shape {
111
110
}
112
111
}
113
112
114
- impl ToNdArrayShape < nd :: Ix6 > for Vec < i64 > {
113
+ impl ToNdArrayShape < ndarray :: Ix6 > for Vec < i64 > {
115
114
type Output = [ usize ; 6 ] ;
116
115
type Error = Error ;
117
116
@@ -132,9 +131,9 @@ mod to_ndarray_shape {
132
131
}
133
132
}
134
133
135
- impl < A , D > TryFromCv < tch:: Tensor > for nd :: Array < A , D >
134
+ impl < A , D > TryFromCv < tch:: Tensor > for ndarray :: Array < A , D >
136
135
where
137
- D : nd :: Dimension ,
136
+ D : ndarray :: Dimension ,
138
137
A : tch:: kind:: Element ,
139
138
Vec < A > : TryFrom < tch:: Tensor , Error = tch:: TchError > ,
140
139
Vec < i64 > : ToNdArrayShape < D , Error = Error > ,
@@ -158,9 +157,9 @@ where
158
157
}
159
158
}
160
159
161
- impl < A , D > TryFromCv < & tch:: Tensor > for nd :: Array < A , D >
160
+ impl < A , D > TryFromCv < & tch:: Tensor > for ndarray :: Array < A , D >
162
161
where
163
- D : nd :: Dimension ,
162
+ D : ndarray :: Dimension ,
164
163
A : tch:: kind:: Element ,
165
164
Vec < A > : TryFrom < tch:: Tensor , Error = tch:: TchError > ,
166
165
Vec < i64 > : ToNdArrayShape < D , Error = Error > ,
@@ -172,13 +171,13 @@ where
172
171
}
173
172
}
174
173
175
- impl < A , S , D > FromCv < & nd :: ArrayBase < S , D > > for tch:: Tensor
174
+ impl < A , S , D > FromCv < & ndarray :: ArrayBase < S , D > > for tch:: Tensor
176
175
where
177
176
A : tch:: kind:: Element + Clone ,
178
- S : nd :: RawData < Elem = A > + nd :: Data ,
179
- D : nd :: Dimension ,
177
+ S : ndarray :: RawData < Elem = A > + ndarray :: Data ,
178
+ D : ndarray :: Dimension ,
180
179
{
181
- fn from_cv ( from : & nd :: ArrayBase < S , D > ) -> Self {
180
+ fn from_cv ( from : & ndarray :: ArrayBase < S , D > ) -> Self {
182
181
let shape: Vec < _ > = from. shape ( ) . iter ( ) . map ( |& s| s as i64 ) . collect ( ) ;
183
182
184
183
match from. as_slice ( ) {
@@ -191,13 +190,13 @@ where
191
190
}
192
191
}
193
192
194
- impl < A , S , D > FromCv < nd :: ArrayBase < S , D > > for tch:: Tensor
193
+ impl < A , S , D > FromCv < ndarray :: ArrayBase < S , D > > for tch:: Tensor
195
194
where
196
195
A : tch:: kind:: Element + Clone ,
197
- S : nd :: RawData < Elem = A > + nd :: Data ,
198
- D : nd :: Dimension ,
196
+ S : ndarray :: RawData < Elem = A > + ndarray :: Data ,
197
+ D : ndarray :: Dimension ,
199
198
{
200
- fn from_cv ( from : nd :: ArrayBase < S , D > ) -> Self {
199
+ fn from_cv ( from : ndarray :: ArrayBase < S , D > ) -> Self {
201
200
Self :: from_cv ( & from)
202
201
}
203
202
}
@@ -218,7 +217,7 @@ mod tests {
218
217
let s2 = 5 ;
219
218
220
219
let tensor = tch:: Tensor :: randn ( [ s0, s1, s2] , tch:: kind:: FLOAT_CPU ) ;
221
- let array: nd :: ArrayD < f32 > = ( & tensor) . try_into_cv ( ) ?;
220
+ let array: ndarray :: ArrayD < f32 > = ( & tensor) . try_into_cv ( ) ?;
222
221
223
222
let is_correct = itertools:: iproduct!( 0 ..s0, 0 ..s1, 0 ..s2) . all ( |( i0, i1, i2) | {
224
223
let lhs: f32 = tensor. i ( ( i0, i1, i2) ) . try_into ( ) . unwrap ( ) ;
@@ -232,7 +231,7 @@ mod tests {
232
231
// Array0
233
232
{
234
233
let tensor = tch:: Tensor :: randn ( [ ] , tch:: kind:: FLOAT_CPU ) ;
235
- let array: nd :: Array0 < f32 > = ( & tensor) . try_into_cv ( ) ?;
234
+ let array: ndarray :: Array0 < f32 > = ( & tensor) . try_into_cv ( ) ?;
236
235
let lhs: f32 = tensor. try_into ( ) . unwrap ( ) ;
237
236
let rhs = array[ ( ) ] ;
238
237
anyhow:: ensure!( lhs == rhs, "value mismatch" ) ;
@@ -242,7 +241,7 @@ mod tests {
242
241
{
243
242
let s0 = 10 ;
244
243
let tensor = tch:: Tensor :: randn ( [ s0] , tch:: kind:: FLOAT_CPU ) ;
245
- let array: nd :: Array1 < f32 > = ( & tensor) . try_into_cv ( ) ?;
244
+ let array: ndarray :: Array1 < f32 > = ( & tensor) . try_into_cv ( ) ?;
246
245
247
246
let is_correct = ( 0 ..s0) . all ( |ind| {
248
247
let lhs: f32 = tensor. i ( ( ind, ) ) . try_into ( ) . unwrap ( ) ;
@@ -259,7 +258,7 @@ mod tests {
259
258
let s1 = 5 ;
260
259
261
260
let tensor = tch:: Tensor :: randn ( [ s0, s1] , tch:: kind:: FLOAT_CPU ) ;
262
- let array: nd :: Array2 < f32 > = ( & tensor) . try_into_cv ( ) ?;
261
+ let array: ndarray :: Array2 < f32 > = ( & tensor) . try_into_cv ( ) ?;
263
262
264
263
let is_correct = itertools:: iproduct!( 0 ..s0, 0 ..s1) . all ( |( i0, i1) | {
265
264
let lhs: f32 = tensor. i ( ( i0, i1) ) . try_into ( ) . unwrap ( ) ;
@@ -277,7 +276,7 @@ mod tests {
277
276
let s2 = 7 ;
278
277
279
278
let tensor = tch:: Tensor :: randn ( [ s0, s1, s2] , tch:: kind:: FLOAT_CPU ) ;
280
- let array: nd :: Array3 < f32 > = ( & tensor) . try_into_cv ( ) ?;
279
+ let array: ndarray :: Array3 < f32 > = ( & tensor) . try_into_cv ( ) ?;
281
280
282
281
let is_correct = itertools:: iproduct!( 0 ..s0, 0 ..s1, 0 ..s2) . all ( |( i0, i1, i2) | {
283
282
let lhs: f32 = tensor. i ( ( i0, i1, i2) ) . try_into ( ) . unwrap ( ) ;
@@ -296,7 +295,7 @@ mod tests {
296
295
let s3 = 11 ;
297
296
298
297
let tensor = tch:: Tensor :: randn ( [ s0, s1, s2, s3] , tch:: kind:: FLOAT_CPU ) ;
299
- let array: nd :: Array4 < f32 > = ( & tensor) . try_into_cv ( ) ?;
298
+ let array: ndarray :: Array4 < f32 > = ( & tensor) . try_into_cv ( ) ?;
300
299
301
300
let is_correct =
302
301
itertools:: iproduct!( 0 ..s0, 0 ..s1, 0 ..s2, 0 ..s3) . all ( |( i0, i1, i2, i3) | {
@@ -317,7 +316,7 @@ mod tests {
317
316
let s4 = 13 ;
318
317
319
318
let tensor = tch:: Tensor :: randn ( [ s0, s1, s2, s3, s4] , tch:: kind:: FLOAT_CPU ) ;
320
- let array: nd :: Array5 < f32 > = ( & tensor) . try_into_cv ( ) ?;
319
+ let array: ndarray :: Array5 < f32 > = ( & tensor) . try_into_cv ( ) ?;
321
320
322
321
let is_correct = itertools:: iproduct!( 0 ..s0, 0 ..s1, 0 ..s2, 0 ..s3, 0 ..s4) . all (
323
322
|( i0, i1, i2, i3, i4) | {
@@ -346,7 +345,7 @@ mod tests {
346
345
let s5 = 17 ;
347
346
348
347
let tensor = tch:: Tensor :: randn ( [ s0, s1, s2, s3, s4, s5] , tch:: kind:: FLOAT_CPU ) ;
349
- let array: nd :: Array6 < f32 > = ( & tensor) . try_into_cv ( ) ?;
348
+ let array: ndarray :: Array6 < f32 > = ( & tensor) . try_into_cv ( ) ?;
350
349
351
350
let is_correct = itertools:: iproduct!( 0 ..s0, 0 ..s1, 0 ..s2, 0 ..s3, 0 ..s4, 0 ..s5) . all (
352
351
|( i0, i1, i2, i3, i4, i5) | {
@@ -377,7 +376,7 @@ mod tests {
377
376
let s1 = 3 ;
378
377
let s2 = 4 ;
379
378
380
- let array = nd :: Array3 :: < f32 > :: from_shape_simple_fn ( [ s0, s1, s2] , || rng. random ( ) ) ;
379
+ let array = ndarray :: Array3 :: < f32 > :: from_shape_simple_fn ( [ s0, s1, s2] , || rng. random ( ) ) ;
381
380
let array = array. reversed_axes ( ) ;
382
381
383
382
let tensor = tch:: Tensor :: from_cv ( & array) ;
0 commit comments