Skip to content

Commit 1a0ba22

Browse files
committed
Push mapper data to reducers after execution
1 parent acecc9c commit 1a0ba22

File tree

2 files changed

+28
-2
lines changed

2 files changed

+28
-2
lines changed

mars/services/storage/api/oscar.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414

1515
import sys
16-
from typing import Any, List, Type, TypeVar, Union
16+
from typing import Any, List, Tuple, Type, TypeVar, Union
1717

1818
from .... import oscar as mo
1919
from ....lib.aio import alru_cache
@@ -163,7 +163,7 @@ async def batch_delete(self, args_list, kwargs_list):
163163
@mo.extensible
164164
async def fetch(
165165
self,
166-
data_key: str,
166+
data_key: Union[str, Tuple],
167167
level: StorageLevel = None,
168168
band_name: str = None,
169169
remote_address: str = None,

mars/services/subtask/worker/processor.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626
Fetch,
2727
FetchShuffle,
2828
execute,
29+
MapReduceOperand,
30+
OperandStage,
2931
)
3032
from ....metrics import Metrics
3133
from ....optimization.physical import optimize
@@ -424,6 +426,28 @@ async def set_chunks_meta():
424426
# set result data size
425427
self.result.data_size = result_data_size
426428

429+
async def push_mapper_data(self, chunk_graph):
    """Ask the reducers' storage services to fetch this subtask's mapper data.

    For every result chunk of ``chunk_graph`` whose key is a tuple, a
    ``fetch`` is scheduled on the storage API of the band mapped to the
    key's second element (the reducer index), with this processor's band
    (``self._band``) passed as the remote source. Fetches aimed at the same
    destination band are sent as one batch, and all batches run concurrently.

    The reducer-index -> band mapping is not available yet (see the TODO
    below), so the method currently returns before scheduling anything.

    Args:
        chunk_graph: graph whose ``result_chunks`` are inspected; each
            chunk must expose a ``key`` attribute.
    """
    # TODO: use task api to get reducer bands
    reducer_idx_to_band = dict()
    if not reducer_idx_to_band:
        # Nothing to push without knowing where the reducers are placed.
        return

    # Keep one storage API and one list of delayed fetches per destination
    # band, so that batching does not rely on how StorageAPI instances
    # hash or compare.
    band_to_storage_api = dict()
    band_to_fetch_tasks = defaultdict(list)
    for result_chunk in chunk_graph.result_chunks:
        key = result_chunk.key
        if not isinstance(key, tuple):
            # Only tuple keys are treated as mapper keys; check the type
            # before indexing so other keys are never indexed.
            continue
        # mapper key is a tuple; its second element selects the reducer
        reducer_idx = key[1]
        address, band_name = reducer_idx_to_band[reducer_idx]
        band = (address, band_name)
        storage_api = band_to_storage_api.get(band)
        if storage_api is None:
            storage_api = StorageAPI(address, self._session_id, band_name)
            band_to_storage_api[band] = storage_api
        fetch_task = storage_api.fetch.delay(
            key, band_name=self._band[1], remote_address=self._band[0]
        )
        band_to_fetch_tasks[band].append(fetch_task)

    # One batched fetch per destination band, awaited together.
    await asyncio.gather(
        *(
            band_to_storage_api[band].fetch.batch(*tasks)
            for band, tasks in band_to_fetch_tasks.items()
        )
    )
450+
427451
async def done(self):
428452
if self.result.status == SubtaskStatus.running:
429453
self.result.status = SubtaskStatus.succeeded
@@ -495,6 +519,8 @@ async def run(self):
495519
await self._unpin_data(input_keys)
496520

497521
await self.done()
522+
# after done, we push mapper data to reducers in advance.
523+
await self.push_mapper_data(chunk_graph)
498524
if self.result.status == SubtaskStatus.succeeded:
499525
cost_time_secs = (
500526
self.result.execution_end_time - self.result.execution_start_time

0 commit comments

Comments
 (0)