Skip to content

Commit d676bcf

Browse files
authored
[Fix] fix_SAM (#3149) (#3150)
1 parent a783bac commit d676bcf

File tree

5 files changed

+376
-7
lines changed

5 files changed

+376
-7
lines changed

contrib/SegmentAnything/segment_anything/automatic_mask_generator.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,6 @@ def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:
166166
crop_box (list(float)): The crop of the image used to generate
167167
the mask, given in XYWH format.
168168
"""
169-
paddle.device.set_device('gpu')
170169
# Generate masks
171170
mask_data = self._generate_masks(image)
172171

contrib/SegmentAnything/segment_anything/modeling/image_encoder.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -437,9 +437,6 @@ def forward(self, x: paddle.Tensor) -> paddle.Tensor:
437437
import torch
438438
from segment_anything.modeling import ImageEncoderViT as ImageEncoderViT_torch
439439

440-
# Set random seed
441-
# np.random.seed(42)
442-
paddle.set_device('gpu')
443440
image_encoder_t = ImageEncoderViT_torch(
444441
depth=12,
445442
embed_dim=768,
@@ -454,7 +451,6 @@ def forward(self, x: paddle.Tensor) -> paddle.Tensor:
454451
global_attn_indexes=[2, 5, 8, 11],
455452
window_size=14,
456453
out_chans=256, )
457-
# image_encoder_t = image_encoder_t.to('cuda:1')
458454

459455
image_encoder = ImageEncoderViT(
460456
depth=12,

contrib/SegmentAnything/segment_anything/modeling/sam.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,10 @@ def __init__(
6363

6464
@property
6565
def device(self) -> Any:
66-
return 'gpu'
66+
if paddle.is_compiled_with_cuda():
67+
return 'gpu'
68+
else:
69+
return 'cpu'
6770

6871
@paddle.no_grad()
6972
def forward(

contrib/SegmentAnything/segment_anything/predictor.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,10 @@ def get_image_embedding(self) -> paddle.Tensor:
270270

271271
@property
272272
def device(self) -> paddle.device:
273-
return 'gpu'
273+
if paddle.is_compiled_with_cuda():
274+
return 'cpu'
275+
else:
276+
return 'gpu'
274277

275278
def reset_image(self) -> None:
276279
"""Resets the currently set image."""

0 commit comments

Comments
 (0)