Skip to content

Commit 63c64d3

Browse files
pritishpai and FastLee authored
Added ability to create account groups from nested ws-local groups (#3818)
## Changes Account level nested groups created using create-account-groups from their workspace-local groups ### Linked issues Resolves #3796 ### Functionality - [ ] added relevant user documentation - [ ] added new CLI command - [ ] modified existing command: `databricks labs ucx ...` - [ ] added a new workflow - [ ] modified existing workflow: `...` - [ ] added a new table - [ ] modified existing table: `...` ### Tests - [ ] manually tested - [ ] added unit tests - [x] added integration tests - [ ] verified on staging environment (screenshot attached) --------- Co-authored-by: Liran Bareket <liran.bareket@databricks.com>
1 parent 10a5c44 commit 63c64d3

File tree

3 files changed

+262
-80
lines changed

3 files changed

+262
-80
lines changed

src/databricks/labs/ucx/account/workspaces.py

Lines changed: 125 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import logging
2+
from dataclasses import dataclass, field
23
from typing import ClassVar
34

45
from databricks.labs.blueprint.installation import Installation
@@ -10,6 +11,20 @@
1011
logger = logging.getLogger(__name__)
1112

1213

14+
@dataclass
class AccountGroupDetails:
    """Identifier and member list recorded for one account-level group."""

    # Identifier of the account group.
    id: str
    # Members reported for the account group; None when no member list is available.
    members: list[ComplexValue] | None = None
18+
19+
20+
@dataclass
class AccountGroupCreationContext:
    """Mutable state shared by the steps of one account-group creation run.

    Every field defaults to its own empty dictionary, so instances never share state.
    """

    # Workspace groups selected for creation at the account, keyed by the account group name to use.
    valid_workspace_groups: dict[str, Group] = field(default_factory=dict)
    # Account groups created or resolved so far, keyed by group name.
    created_groups: dict[str, Group] = field(default_factory=dict)
    # Workspace group name -> name under which the group is created at the account.
    renamed_groups: dict[str, str] = field(default_factory=dict)
    # Account groups found in the account listing, keyed by display name.
    preexisting_account_groups: dict[str, AccountGroupDetails] = field(default_factory=dict)
26+
27+
1328
class AccountWorkspaces:
1429
SYNC_FILE_NAME: ClassVar[str] = "workspaces.json"
1530

@@ -76,21 +91,101 @@ def sync_workspace_info(self, workspaces: list[Workspace] | None = None):
7691
except (PermissionDenied, NotFound, ValueError):
7792
logger.warning(f"Failed to save workspace info for {ws.config.host}")
7893

79-
def create_account_level_groups(self, prompts: Prompts):
80-
acc_groups = self._get_account_groups()
94+
def create_account_level_groups(self, prompts: Prompts) -> None:
95+
"""
96+
Create account level groups from workspace groups
97+
98+
The following approach is used:
99+
Get all valid workspace groups from all workspaces
100+
101+
For each group:
102+
- Check if the group already exists in the account
103+
- If it does not exist, check if it is a nested group (users are added directly)
104+
- If it's a nested group, follow the same approach recursively
105+
- If it is a regular group, create the group in the account and add all members to the group
106+
"""
107+
context = AccountGroupCreationContext()
108+
context.preexisting_account_groups = self._get_account_groups()
81109
workspace_ids = [workspace.workspace_id for workspace in self._workspaces()]
82110
if not workspace_ids:
83111
raise ValueError("The workspace ids provided are not found in the account, Please check and try again.")
84-
all_valid_workspace_groups = self._get_valid_workspaces_groups(prompts, workspace_ids)
112+
context.valid_workspace_groups = self._get_valid_workspaces_groups(prompts, workspace_ids, context)
85113

86-
for group_name, valid_group in all_valid_workspace_groups.items():
87-
acc_group = self._try_create_account_groups(group_name, acc_groups)
114+
for group_name, valid_group in context.valid_workspace_groups.items():
115+
self._create_account_groups_recursively(group_name, valid_group, context)
88116

89-
if not acc_group or not valid_group.members or not acc_group.id:
90-
continue
91-
if len(valid_group.members) > 0:
92-
self._add_members_to_acc_group(self._ac, acc_group.id, group_name, valid_group)
93-
logger.info(f"Group {group_name} created in the account")
117+
def _create_account_groups_recursively(
    self, group_name: str, valid_group: Group, context: AccountGroupCreationContext
) -> None:
    """Create the account group for a workspace group, handling nested groups first.

    User members are added to the account group as they are. Group members are
    resolved through `_handle_nested_group`, and the account-level member it returns
    is added instead of the workspace-local one.

    Args:
        group_name: Name of the account group to create. It may differ from
            `valid_group.display_name` when the workspace group was renamed.
        valid_group: Workspace group whose members are replicated.
        context: State shared across the run. On success the created account group
            is stored in `context.created_groups` under `group_name`.
    """
    if group_name in context.created_groups:
        logger.info(f"Group {group_name} already exist in the account, ignoring")
        return

    members_to_add = []
    # `members` may be None; such a group is still created, only without members.
    for member in valid_group.members or []:
        ref = member.ref or ""
        if ref.startswith("Users"):
            members_to_add.append(member)
        elif ref.startswith("Groups"):
            if member.display is None:
                # Nested groups are looked up by name, so a nameless reference cannot be resolved.
                logger.warning(f"Nested group {member.ref} has no display name, skipping")
                continue
            nested_member = self._handle_nested_group(member.display, context)
            if nested_member:
                members_to_add.append(nested_member)
        else:
            logger.warning(f"Member {member.ref} is not a user or group, skipping")

    acc_group = self._try_create_account_groups(group_name, context.preexisting_account_groups)
    if not acc_group:
        return

    logger.info(f"Successfully created account group {acc_group.display_name}")
    if members_to_add and acc_group.id:
        self._add_members_to_acc_group(self._ac, acc_group.id, group_name, members_to_add)
    created_acc_group = self._safe_groups_get(self._ac, acc_group.id)
    if not created_acc_group:
        logger.warning(f"Newly created group {valid_group.display_name} could not be fetched, skipping")
        return
    # Keyed by `group_name` (not the workspace display name) so the lookups above and in
    # `_handle_nested_group`, which use the possibly renamed name, find this entry.
    context.created_groups[group_name] = created_acc_group
151+
152+
def _handle_nested_group(self, group_name: str, context: AccountGroupCreationContext) -> ComplexValue | None:
    """Resolve a nested workspace group to the account-level member that represents it.

    The group is looked up among the account groups that already existed, then among
    the groups created during this run; if it is in neither, it is created through
    `_create_account_groups_recursively`.

    Args:
        group_name: Workspace-level name of the nested group.
        context: State shared across the run; `created_groups` is updated as groups
            are resolved.

    Returns:
        A `ComplexValue` referencing the account group, or None when the group could
        not be fetched, is unknown, or could not be created.
    """
    # A renamed workspace group is tracked under the name used at the account.
    group_name = context.renamed_groups.get(group_name, group_name)

    # check if account group was created before this run
    preexisting = context.preexisting_account_groups.get(group_name)
    if preexisting is not None:
        logger.info(f"Group {group_name} already exist in the account, ignoring")
        full_account_group = self._safe_groups_get(self._ac, preexisting.id)
        if not full_account_group:
            logger.warning(f"Group {group_name} could not be fetched, skipping")
            return None
        context.created_groups[group_name] = full_account_group

    # check if workspace group is already created at account level in current run
    if group_name not in context.created_groups:
        nested_workspace_group = context.valid_workspace_groups.get(group_name)
        if nested_workspace_group is None:
            # The nested group was not collected from any workspace, so there is nothing
            # to create it from; skip the member instead of failing the whole run.
            logger.warning(f"Group {group_name} is not a valid workspace group, skipping")
            return None
        self._create_account_groups_recursively(group_name, nested_workspace_group, context)

    created_acc_group = context.created_groups.get(group_name)
    if created_acc_group is None:
        logger.warning(f"Group {group_name} could not be fetched, skipping")
        return None

    # the AccountGroupsAPI expects the members to be in the form of ComplexValue
    return ComplexValue(
        display=created_acc_group.display_name,
        ref=f"Groups/{created_acc_group.id}",
        value=created_acc_group.id,
    )
94189

95190
def get_accessible_workspaces(self) -> list[Workspace]:
96191
"""
@@ -126,9 +221,7 @@ def can_administer(self, workspace: Workspace) -> bool:
126221
return False
127222
return True
128223

129-
def _try_create_account_groups(
130-
self, group_name: str, acc_groups: dict[str | None, list[ComplexValue] | None]
131-
) -> Group | None:
224+
def _try_create_account_groups(self, group_name: str, acc_groups: dict[str, AccountGroupDetails]) -> Group | None:
132225
try:
133226
if group_name in acc_groups:
134227
logger.info(f"Group {group_name} already exist in the account, ignoring")
@@ -139,9 +232,9 @@ def _try_create_account_groups(
139232
return None
140233

141234
def _add_members_to_acc_group(
142-
self, acc_client: AccountClient, acc_group_id: str, group_name: str, valid_group: Group
235+
self, acc_client: AccountClient, acc_group_id: str, group_name: str, group_members: list[ComplexValue] | None
143236
):
144-
for chunk in self._chunks(valid_group.members, 20):
237+
for chunk in self._chunks(group_members, 20):
145238
logger.debug(f"Adding {len(chunk)} members to acc group {group_name}")
146239
acc_client.groups.patch(
147240
acc_group_id,
@@ -155,17 +248,25 @@ def _chunks(lst, chunk_size):
155248
for i in range(0, len(lst), chunk_size):
156249
yield lst[i : i + chunk_size]
157250

158-
def _get_valid_workspaces_groups(self, prompts: Prompts, workspace_ids: list[int]) -> dict[str, Group]:
251+
def _get_valid_workspaces_groups(
252+
self, prompts: Prompts, workspace_ids: list[int], context: AccountGroupCreationContext
253+
) -> dict[str, Group]:
159254
all_workspaces_groups: dict[str, Group] = {}
160255

161256
for workspace in self._workspaces():
162257
if workspace.workspace_id not in workspace_ids:
163258
continue
164-
self._load_workspace_groups(prompts, workspace, all_workspaces_groups)
259+
self._load_workspace_groups(prompts, workspace, all_workspaces_groups, context)
165260

166261
return all_workspaces_groups
167262

168-
def _load_workspace_groups(self, prompts, workspace, all_workspaces_groups):
263+
def _load_workspace_groups(
264+
self,
265+
prompts: Prompts,
266+
workspace: Workspace,
267+
all_workspaces_groups: dict[str, Group],
268+
context: AccountGroupCreationContext,
269+
) -> None:
169270
client = self.client_for(workspace)
170271
logger.info(f"Crawling groups in workspace {client.config.host}")
171272
ws_group_ids = client.groups.list(attributes="id")
@@ -188,6 +289,7 @@ def _load_workspace_groups(self, prompts, workspace, all_workspaces_groups):
188289
f"it will be created at the account with name : {workspace.workspace_name}_{group_name}"
189290
):
190291
all_workspaces_groups[f"{workspace.workspace_name}_{group_name}"] = full_workspace_group
292+
context.renamed_groups[group_name] = f"{workspace.workspace_name}_{group_name}"
191293
continue
192294
logger.info(f"Found new group {group_name}")
193295
all_workspaces_groups[group_name] = full_workspace_group
@@ -212,7 +314,7 @@ def _has_same_members(group_1: Group, group_2: Group) -> bool:
212314
ws_members_set_2 = set([m.display for m in group_2.members] if group_2.members else [])
213315
return not bool((ws_members_set_1 - ws_members_set_2).union(ws_members_set_2 - ws_members_set_1))
214316

215-
def _get_account_groups(self) -> dict[str | None, list[ComplexValue] | None]:
317+
def _get_account_groups(self) -> dict[str, AccountGroupDetails]:
216318
logger.debug("Listing groups in account")
217319
acc_groups = {}
218320
for acc_grp_id in self._ac.groups.list(attributes="id"):
@@ -222,7 +324,10 @@ def _get_account_groups(self) -> dict[str | None, list[ComplexValue] | None]:
222324
if not full_account_group:
223325
continue
224326
logger.debug(f"Found account group {full_account_group.display_name}")
225-
acc_groups[full_account_group.display_name] = full_account_group.members
327+
assert full_account_group.display_name is not None, "group name undefined"
328+
acc_groups[full_account_group.display_name] = AccountGroupDetails(
329+
id=acc_grp_id.id, members=full_account_group.members
330+
)
226331

227332
logger.info(f"{len(acc_groups)} account groups found")
228333
return acc_groups

tests/integration/account/test_account.py

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from databricks.labs.blueprint.tui import MockPrompts
55
from databricks.sdk import AccountClient
66
from databricks.sdk.retries import retried
7-
from databricks.sdk.service.iam import Group
7+
from databricks.sdk.service.iam import Group, User
88

99
from databricks.labs.ucx.account.workspaces import AccountWorkspaces
1010

@@ -64,3 +64,51 @@ def get_group(display_name: str) -> Group:
6464

6565
group = get_group(group_display_name)
6666
assert group
67+
68+
69+
def test_create_account_level_groups_nested_groups(
    make_group, make_user, acc, ws, make_random, clean_account_level_groups, watchdog_purge_suffix, runtime_ctx, caplog
):
    """Nested workspace groups are recreated at the account with the same number of members."""
    suffix = f"{make_random(4).lower()}-{watchdog_purge_suffix}"
    # Test groups:
    # 1. group a contains group_b and group_c.
    # 2. group b contains user1, user2 and group_d.
    # 3. group c contains group_d.
    # 4. group d contains user3 and user4.

    users: list[User] = [make_user() for _ in range(4)]

    # Created innermost first, because a group can only reference members that already exist.
    ws_groups = list[Group]()
    ws_groups.append(
        make_group(display_name=f"created_by_ucx_regular_group_d-{suffix}", members=[users[2].id, users[3].id])
    )
    # `members` is passed as a list like every other call; a bare string would be iterated per character.
    ws_groups.append(make_group(display_name=f"created_by_ucx_regular_group_c-{suffix}", members=[ws_groups[0].id]))
    ws_groups.append(
        make_group(
            display_name=f"created_by_ucx_regular_group_b-{suffix}", members=[users[0].id, users[1].id, ws_groups[0].id]
        )
    )
    ws_groups.append(
        make_group(display_name=f"created_by_ucx_regular_group_a-{suffix}", members=[ws_groups[1].id, ws_groups[2].id])
    )

    AccountWorkspaces(acc, [ws.get_workspace_id()]).create_account_level_groups(MockPrompts({}))

    # Retried on KeyError until the group shows up in the account listing or the timeout expires.
    @retried(on=[KeyError], timeout=timedelta(minutes=2))
    def get_group(display_name: str) -> Group:
        for grp in acc.groups.list():
            if grp.display_name == display_name:
                return grp
        raise KeyError(f"Group not found {display_name}")

    for ws_group in ws_groups:
        group_display_name = ws_group.display_name
        group = get_group(group_display_name)
        assert group
        assert len(group.members) == len(ws_group.members)

    runtime_ctx.group_manager.validate_group_membership()

    assert 'There are no groups with different membership between account and workspace.' in caplog.text

0 commit comments

Comments
 (0)