Skip to content

Commit 3b7598e

Browse files
authored
Added to torch_sparse_csc_tensor (#319)
update
1 parent e6cf558 commit 3b7598e

File tree

2 files changed

+13
-4
lines changed

2 files changed

+13
-4
lines changed

torch_sparse/storage.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -282,9 +282,9 @@ def sparse_resize(self, sparse_sizes: Tuple[int, int]):
282282
rowcount = torch.cat([rowcount, rowcount.new_zeros(diff_0)])
283283
elif diff_0 < 0:
284284
if rowptr is not None:
285-
rowptr = rowptr[:-diff_0]
285+
rowptr = rowptr[:diff_0]
286286
if rowcount is not None:
287-
rowcount = rowcount[:-diff_0]
287+
rowcount = rowcount[:diff_0]
288288

289289
diff_1 = sparse_sizes[1] - old_sparse_sizes[1]
290290
colcount, colptr = self._colcount, self._colptr
@@ -295,9 +295,9 @@ def sparse_resize(self, sparse_sizes: Tuple[int, int]):
295295
colcount = torch.cat([colcount, colcount.new_zeros(diff_1)])
296296
elif diff_1 < 0:
297297
if colptr is not None:
298-
colptr = colptr[:-diff_1]
298+
colptr = colptr[:diff_1]
299299
if colcount is not None:
300-
colcount = colcount[:-diff_1]
300+
colcount = colcount[:diff_1]
301301

302302
return SparseStorage(
303303
row=self._row,

torch_sparse/tensor.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -510,6 +510,15 @@ def to_torch_sparse_csr_tensor(
510510

511511
return torch.sparse_csr_tensor(rowptr, col, value, self.sizes())
512512

513+
def to_torch_sparse_csc_tensor(
514+
self, dtype: Optional[int] = None) -> torch.Tensor:
515+
colptr, row, value = self.csc()
516+
517+
if value is None:
518+
value = torch.ones(self.nnz(), dtype=dtype, device=self.device())
519+
520+
return torch.sparse_csc_tensor(colptr, row, value, self.sizes())
521+
513522

514523
# Python Bindings #############################################################
515524

0 commit comments

Comments
 (0)