Skip to content

Commit 8a6a887

Browse files
fix(ui): ensure staging area always has the right state and session association
1 parent 12d9862 commit 8a6a887

File tree

3 files changed

+55
-37
lines changed

3 files changed

+55
-37
lines changed

invokeai/frontend/web/src/features/controlLayers/components/StagingArea/context.tsx

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ import { selectBboxRect, selectSelectedEntityIdentifier } from 'features/control
1515
import type { CanvasRasterLayerState } from 'features/controlLayers/store/types';
1616
import { imageNameToImageObject } from 'features/controlLayers/store/util';
1717
import type { PropsWithChildren } from 'react';
18-
import { createContext, memo, useContext, useEffect, useMemo } from 'react';
18+
import { createContext, memo, useContext, useEffect, useMemo, useState } from 'react';
1919
import { getImageDTOSafe } from 'services/api/endpoints/images';
2020
import { queueApi } from 'services/api/endpoints/queue';
2121
import type { S } from 'services/api/types';
@@ -94,18 +94,24 @@ export const StagingAreaContextProvider = memo(({ children, sessionId }: PropsWi
9494

9595
return _stagingAreaAppApi;
9696
}, [sessionId, socket, store]);
97-
const value = useMemo(() => {
98-
return new StagingAreaApi(sessionId, stagingAreaAppApi);
99-
}, [sessionId, stagingAreaAppApi]);
97+
98+
const [stagingAreaApi] = useState(() => new StagingAreaApi());
10099

101100
useEffect(() => {
102-
const api = value;
101+
stagingAreaApi.connectToApp(sessionId, stagingAreaAppApi);
102+
103+
// We need to subscribe to the queue items query manually to ensure the staging area actually gets the items
104+
const { unsubscribe: unsubQueueItemsQuery } = store.dispatch(
105+
queueApi.endpoints.listAllQueueItems.initiate({ destination: sessionId })
106+
);
107+
103108
return () => {
104-
api.cleanup();
109+
stagingAreaApi.cleanup();
110+
unsubQueueItemsQuery();
105111
};
106-
}, [value]);
112+
}, [sessionId, stagingAreaApi, stagingAreaAppApi, store]);
107113

108-
return <StagingAreaContext.Provider value={value}>{children}</StagingAreaContext.Provider>;
114+
return <StagingAreaContext.Provider value={stagingAreaApi}>{children}</StagingAreaContext.Provider>;
109115
});
110116
StagingAreaContextProvider.displayName = 'StagingAreaContextProvider';
111117

invokeai/frontend/web/src/features/controlLayers/components/StagingArea/state.test.ts

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@ describe('StagingAreaApi', () => {
1616

1717
beforeEach(() => {
1818
mockApp = createMockStagingAreaApp();
19-
api = new StagingAreaApi(sessionId, mockApp);
19+
api = new StagingAreaApi();
20+
api.connectToApp(sessionId, mockApp);
2021
});
2122

2223
afterEach(() => {
@@ -25,7 +26,7 @@ describe('StagingAreaApi', () => {
2526

2627
describe('Constructor and Setup', () => {
2728
it('should initialize with correct session ID', () => {
28-
expect(api.sessionId).toBe(sessionId);
29+
expect(api._sessionId).toBe(sessionId);
2930
});
3031

3132
it('should set up event subscriptions', () => {
@@ -747,8 +748,10 @@ describe('StagingAreaApi', () => {
747748

748749
describe('Event Subscription Management', () => {
749750
it('should handle multiple subscriptions and unsubscriptions', () => {
750-
const api2 = new StagingAreaApi(sessionId, mockApp);
751-
const api3 = new StagingAreaApi(sessionId, mockApp);
751+
const api2 = new StagingAreaApi();
752+
api2.connectToApp(sessionId, mockApp);
753+
const api3 = new StagingAreaApi();
754+
api3.connectToApp(sessionId, mockApp);
752755

753756
// All should be subscribed
754757
expect(mockApp.onItemsChanged).toHaveBeenCalledTimes(3);

invokeai/frontend/web/src/features/controlLayers/components/StagingArea/state.ts

Lines changed: 34 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -61,18 +61,14 @@ type ProgressDataMap = Record<number, ProgressData | undefined>;
6161
* and configure auto-switching behavior.
6262
*/
6363
export class StagingAreaApi {
64-
sessionId: string;
65-
_app: StagingAreaAppApi;
66-
_subscriptions = new Set<() => void>();
64+
/** The current session ID. */
65+
_sessionId: string | null = null;
6766

68-
constructor(sessionId: string, app: StagingAreaAppApi) {
69-
this.sessionId = sessionId;
70-
this._app = app;
67+
/** The app API */
68+
_app: StagingAreaAppApi | null = null;
7169

72-
this._subscriptions.add(this._app.onItemsChanged(this.onItemsChangedEvent));
73-
this._subscriptions.add(this._app.onQueueItemStatusChanged(this.onQueueItemStatusChangedEvent));
74-
this._subscriptions.add(this._app.onInvocationProgress(this.onInvocationProgressEvent));
75-
}
70+
/** A set of subscriptions to be cleaned up when we are finished with a session */
71+
_subscriptions = new Set<() => void>();
7672

7773
/** Item ID of the last started item. Used for auto-switch on start. */
7874
$lastStartedItemId = atom<number | null>(null);
@@ -136,7 +132,7 @@ export class StagingAreaApi {
136132
/** Selects a queue item by ID. */
137133
select = (itemId: number) => {
138134
this.$selectedItemId.set(itemId);
139-
this._app.onSelect?.(itemId);
135+
this._app?.onSelect?.(itemId);
140136
};
141137

142138
/** Selects the next item in the queue, wrapping to the first item if at the end. */
@@ -152,7 +148,7 @@ export class StagingAreaApi {
152148
return;
153149
}
154150
this.$selectedItemId.set(nextItem.item_id);
155-
this._app.onSelectNext?.();
151+
this._app?.onSelectNext?.();
156152
};
157153

158154
/** Selects the previous item in the queue, wrapping to the last item if at the beginning. */
@@ -168,7 +164,7 @@ export class StagingAreaApi {
168164
return;
169165
}
170166
this.$selectedItemId.set(prevItem.item_id);
171-
this._app.onSelectPrev?.();
167+
this._app?.onSelectPrev?.();
172168
};
173169

174170
/** Selects the first item in the queue. */
@@ -179,7 +175,7 @@ export class StagingAreaApi {
179175
return;
180176
}
181177
this.$selectedItemId.set(first.item_id);
182-
this._app.onSelectFirst?.();
178+
this._app?.onSelectFirst?.();
183179
};
184180

185181
/** Selects the last item in the queue. */
@@ -190,7 +186,7 @@ export class StagingAreaApi {
190186
return;
191187
}
192188
this.$selectedItemId.set(last.item_id);
193-
this._app.onSelectLast?.();
189+
this._app?.onSelectLast?.();
194190
};
195191

196192
/** Discards the currently selected item and selects the next available item. */
@@ -207,7 +203,7 @@ export class StagingAreaApi {
207203
} else {
208204
this.$selectedItemId.set(null);
209205
}
210-
this._app.onDiscard?.(selectedItem.item);
206+
this._app?.onDiscard?.(selectedItem.item);
211207
};
212208

213209
/** Whether the discard selected action is enabled. */
@@ -218,10 +214,23 @@ export class StagingAreaApi {
218214
return true;
219215
});
220216

217+
/** Connects to the app, registering listeners and such */
218+
connectToApp = (sessionId: string, app: StagingAreaAppApi) => {
219+
if (this._sessionId !== sessionId) {
220+
this.cleanup();
221+
this._sessionId = sessionId;
222+
}
223+
this._app = app;
224+
225+
this._subscriptions.add(this._app.onItemsChanged(this.onItemsChangedEvent));
226+
this._subscriptions.add(this._app.onQueueItemStatusChanged(this.onQueueItemStatusChangedEvent));
227+
this._subscriptions.add(this._app.onInvocationProgress(this.onInvocationProgressEvent));
228+
};
229+
221230
/** Discards all items in the queue. */
222231
discardAll = () => {
223232
this.$selectedItemId.set(null);
224-
this._app.onDiscardAll?.();
233+
this._app?.onDiscardAll?.();
225234
};
226235

227236
/** Accepts the currently selected item if an image is available. */
@@ -235,7 +244,7 @@ export class StagingAreaApi {
235244
if (!datum || !datum.imageDTO) {
236245
return;
237246
}
238-
this._app.onAccept?.(selectedItem.item, datum.imageDTO);
247+
this._app?.onAccept?.(selectedItem.item, datum.imageDTO);
239248
};
240249

241250
/** Whether the accept selected action is enabled. */
@@ -249,20 +258,20 @@ export class StagingAreaApi {
249258

250259
/** Sets the auto-switch mode. */
251260
setAutoSwitch = (mode: AutoSwitchMode) => {
252-
this._app.onAutoSwitchChange?.(mode);
261+
this._app?.onAutoSwitchChange?.(mode);
253262
};
254263

255264
/** Handles invocation progress events from the WebSocket. */
256265
onInvocationProgressEvent = (data: S['InvocationProgressEvent']) => {
257-
if (data.destination !== this.sessionId) {
266+
if (data.destination !== this._sessionId) {
258267
return;
259268
}
260269
setProgress(this.$progressData, data);
261270
};
262271

263272
/** Handles queue item status change events from the WebSocket. */
264273
onQueueItemStatusChangedEvent = (data: S['QueueItemStatusChangedEvent']) => {
265-
if (data.destination !== this.sessionId) {
274+
if (data.destination !== this._sessionId) {
266275
return;
267276
}
268277
if (data.status === 'completed') {
@@ -277,7 +286,7 @@ export class StagingAreaApi {
277286
*/
278287
this.$lastCompletedItemId.set(data.item_id);
279288
}
280-
if (data.status === 'in_progress' && this._app.getAutoSwitch() === 'switch_on_start') {
289+
if (data.status === 'in_progress' && this._app?.getAutoSwitch() === 'switch_on_start') {
281290
this.$lastStartedItemId.set(data.item_id);
282291
}
283292
};
@@ -327,7 +336,7 @@ export class StagingAreaApi {
327336
for (const item of items) {
328337
const datum = progressData[item.item_id];
329338

330-
if (this.$lastStartedItemId.get() === item.item_id && this._app.getAutoSwitch() === 'switch_on_start') {
339+
if (this.$lastStartedItemId.get() === item.item_id && this._app?.getAutoSwitch() === 'switch_on_start') {
331340
this.$selectedItemId.set(item.item_id);
332341
this.$lastStartedItemId.set(null);
333342
}
@@ -339,13 +348,13 @@ export class StagingAreaApi {
339348
if (!outputImageName) {
340349
continue;
341350
}
342-
const imageDTO = await this._app.getImageDTO(outputImageName);
351+
const imageDTO = await this._app?.getImageDTO(outputImageName);
343352
if (!imageDTO) {
344353
continue;
345354
}
346355

347356
// This is the load logic mentioned in the comment in the QueueItemStatusChangedEvent handler above.
348-
if (this.$lastCompletedItemId.get() === item.item_id && this._app.getAutoSwitch() === 'switch_on_finish') {
357+
if (this.$lastCompletedItemId.get() === item.item_id && this._app?.getAutoSwitch() === 'switch_on_finish') {
349358
this._app.loadImage(imageDTO.image_url).then(() => {
350359
this.$selectedItemId.set(item.item_id);
351360
this.$lastCompletedItemId.set(null);

0 commit comments

Comments
 (0)