1
1
import logging
2
+ from dataclasses import dataclass , field
2
3
from typing import ClassVar
3
4
4
5
from databricks .labs .blueprint .installation import Installation
10
11
logger = logging .getLogger (__name__ )
11
12
12
13
14
@dataclass
class AccountGroupDetails:
    """Identifier and member list recorded for a group that exists at the account level."""

    # Identifier of the account group.
    id: str
    # Members of the account group; None when no member list is available.
    members: list[ComplexValue] | None = None
18
+
19
+
20
@dataclass
class AccountGroupCreationContext:
    """Bookkeeping shared by the helpers that create account level groups during one run."""

    # Workspace groups selected for creation, keyed by the name they get at the account.
    valid_workspace_groups: dict[str, Group] = field(default_factory=dict)
    # Account groups resolved or created so far in this run, keyed by group name.
    created_groups: dict[str, Group] = field(default_factory=dict)
    # Original workspace group name -> "<workspace name>_<group name>" it was renamed to.
    renamed_groups: dict[str, str] = field(default_factory=dict)
    # Account groups that existed before the run started, keyed by display name.
    preexisting_account_groups: dict[str, AccountGroupDetails] = field(default_factory=dict)
26
+
27
+
13
28
class AccountWorkspaces :
14
29
SYNC_FILE_NAME : ClassVar [str ] = "workspaces.json"
15
30
@@ -76,21 +91,101 @@ def sync_workspace_info(self, workspaces: list[Workspace] | None = None):
76
91
except (PermissionDenied , NotFound , ValueError ):
77
92
logger .warning (f"Failed to save workspace info for { ws .config .host } " )
78
93
79
- def create_account_level_groups (self , prompts : Prompts ):
80
- acc_groups = self ._get_account_groups ()
94
def create_account_level_groups(self, prompts: Prompts) -> None:
    """
    Create account level groups from workspace groups.

    Approach:
      1. Record the groups that already exist in the account.
      2. Collect the valid workspace groups from all workspaces.
      3. For each valid workspace group, create it at the account level.
         Nested groups are created recursively before the group that
         contains them; users are added to their group directly.

    Raises:
        ValueError: when no workspace is found in the account.
    """
    # Account groups are listed first so that every later step can tell
    # pre-existing groups apart from the ones created during this run.
    context = AccountGroupCreationContext(preexisting_account_groups=self._get_account_groups())

    known_workspace_ids = [ws.workspace_id for ws in self._workspaces()]
    if not known_workspace_ids:
        raise ValueError("The workspace ids provided are not found in the account, Please check and try again.")

    context.valid_workspace_groups = self._get_valid_workspaces_groups(prompts, known_workspace_ids, context)
    for name, workspace_group in context.valid_workspace_groups.items():
        self._create_account_groups_recursively(name, workspace_group, context)
88
116
89
- if not acc_group or not valid_group .members or not acc_group .id :
90
- continue
91
- if len (valid_group .members ) > 0 :
92
- self ._add_members_to_acc_group (self ._ac , acc_group .id , group_name , valid_group )
93
- logger .info (f"Group { group_name } created in the account" )
117
def _create_account_groups_recursively(
    self, group_name: str, valid_group: Group, context: AccountGroupCreationContext
) -> None:
    """
    Create the account level group ``group_name`` from ``valid_group``.

    Nested group members are resolved (and created when needed) before the
    group itself is created, so they can be added as members in one pass.

    Args:
        group_name: name the group gets at the account level. This is the key
            used in ``context.valid_workspace_groups`` and may differ from
            ``valid_group.display_name`` when the group was renamed.
        valid_group: workspace group whose members are copied.
        context: shared bookkeeping; ``created_groups`` is updated on success.
    """
    if group_name in context.created_groups:
        logger.info(f"Group {group_name} already exist in the account, ignoring")
        return

    members_to_add = []
    # A group can carry no member list at all; treat that as "no members"
    # instead of aborting the whole run, the group itself is still created.
    for member in valid_group.members or []:
        if member.ref and member.ref.startswith("Users"):
            members_to_add.append(member)
        elif member.ref and member.ref.startswith("Groups"):
            if member.display is None:
                logger.warning(f"Nested group member {member.ref} has no display name, skipping")
                continue
            nested_member = self._handle_nested_group(member.display, context)
            if nested_member is not None:
                members_to_add.append(nested_member)
        else:
            logger.warning(f"Member {member.ref} is not a user or group, skipping")

    acc_group = self._try_create_account_groups(group_name, context.preexisting_account_groups)
    if not acc_group:
        return

    logger.info(f"Successfully created account group {acc_group.display_name}")
    if members_to_add and acc_group.id:
        self._add_members_to_acc_group(self._ac, acc_group.id, group_name, members_to_add)

    # Fetch the group again so the cached entry reflects the members just added.
    created_acc_group = self._safe_groups_get(self._ac, acc_group.id)
    if not created_acc_group:
        logger.warning(f"Newly created group {group_name} could not be fetched, skipping")
        return

    # Key by the account level name: every lookup of created_groups uses
    # group_name, which differs from valid_group.display_name for renamed
    # groups. Keying by display_name would hide the renamed group from its
    # parents and make the un-renamed sibling look as if it already existed.
    context.created_groups[group_name] = created_acc_group
151
+
152
def _handle_nested_group(self, group_name: str, context: AccountGroupCreationContext) -> ComplexValue | None:
    """
    Resolve a nested workspace group to a member entry for its parent account group.

    The account group is looked up among the groups already resolved in this
    run, then among the groups that existed in the account before the run,
    and is otherwise created through ``_create_account_groups_recursively``.

    Args:
        group_name: display name of the nested group as seen in the workspace.
        context: shared bookkeeping; ``created_groups`` is updated when the
            group is resolved.

    Returns:
        A ``ComplexValue`` referencing the account group, or ``None`` when the
        group could not be resolved (a warning is logged in that case).
    """
    # The group may be created at the account under a different name.
    group_name = context.renamed_groups.get(group_name, group_name)

    # Only resolve the group once per run; later references reuse the cached entry.
    if group_name not in context.created_groups:
        preexisting = context.preexisting_account_groups.get(group_name)
        if preexisting is not None:
            # The account group was created before this run: fetch it in full.
            logger.info(f"Group {group_name} already exist in the account, ignoring")
            full_account_group = self._safe_groups_get(self._ac, preexisting.id)
            if not full_account_group:
                logger.warning(f"Group {group_name} could not be fetched, skipping")
                return None
            context.created_groups[group_name] = full_account_group
        else:
            workspace_group = context.valid_workspace_groups.get(group_name)
            if workspace_group is None:
                # Nothing to create the group from; skip this member rather
                # than abort the whole run with a KeyError.
                logger.warning(f"Nested group {group_name} is not a valid workspace group, skipping")
                return None
            self._create_account_groups_recursively(group_name, workspace_group, context)

    created_acc_group = context.created_groups.get(group_name)
    if created_acc_group is None:
        logger.warning(f"Group {group_name} could not be fetched, skipping")
        return None

    # the AccountGroupsAPI expects the members to be in the form of ComplexValue
    return ComplexValue(
        display=created_acc_group.display_name,
        ref=f"Groups/{created_acc_group.id}",
        value=created_acc_group.id,
    )
94
189
95
190
def get_accessible_workspaces (self ) -> list [Workspace ]:
96
191
"""
@@ -126,9 +221,7 @@ def can_administer(self, workspace: Workspace) -> bool:
126
221
return False
127
222
return True
128
223
129
- def _try_create_account_groups (
130
- self , group_name : str , acc_groups : dict [str | None , list [ComplexValue ] | None ]
131
- ) -> Group | None :
224
+ def _try_create_account_groups (self , group_name : str , acc_groups : dict [str , AccountGroupDetails ]) -> Group | None :
132
225
try :
133
226
if group_name in acc_groups :
134
227
logger .info (f"Group { group_name } already exist in the account, ignoring" )
@@ -139,9 +232,9 @@ def _try_create_account_groups(
139
232
return None
140
233
141
234
def _add_members_to_acc_group (
142
- self , acc_client : AccountClient , acc_group_id : str , group_name : str , valid_group : Group
235
+ self , acc_client : AccountClient , acc_group_id : str , group_name : str , group_members : list [ ComplexValue ] | None
143
236
):
144
- for chunk in self ._chunks (valid_group . members , 20 ):
237
+ for chunk in self ._chunks (group_members , 20 ):
145
238
logger .debug (f"Adding { len (chunk )} members to acc group { group_name } " )
146
239
acc_client .groups .patch (
147
240
acc_group_id ,
@@ -155,17 +248,25 @@ def _chunks(lst, chunk_size):
155
248
for i in range (0 , len (lst ), chunk_size ):
156
249
yield lst [i : i + chunk_size ]
157
250
158
- def _get_valid_workspaces_groups (self , prompts : Prompts , workspace_ids : list [int ]) -> dict [str , Group ]:
251
def _get_valid_workspaces_groups(
    self, prompts: Prompts, workspace_ids: list[int], context: AccountGroupCreationContext
) -> dict[str, Group]:
    """
    Gather the workspace groups of every workspace listed in ``workspace_ids``.

    Each selected workspace is handed to ``_load_workspace_groups``, which
    fills the returned mapping and receives ``context`` alongside it.

    Returns:
        Mapping from group name to the workspace group collected for it.
    """
    collected: dict[str, Group] = {}
    # Lazy filter: workspaces are still visited and loaded one at a time, in order.
    selected = (ws for ws in self._workspaces() if ws.workspace_id in workspace_ids)
    for ws in selected:
        self._load_workspace_groups(prompts, ws, collected, context)
    return collected
167
262
168
- def _load_workspace_groups (self , prompts , workspace , all_workspaces_groups ):
263
+ def _load_workspace_groups (
264
+ self ,
265
+ prompts : Prompts ,
266
+ workspace : Workspace ,
267
+ all_workspaces_groups : dict [str , Group ],
268
+ context : AccountGroupCreationContext ,
269
+ ) -> None :
169
270
client = self .client_for (workspace )
170
271
logger .info (f"Crawling groups in workspace { client .config .host } " )
171
272
ws_group_ids = client .groups .list (attributes = "id" )
@@ -188,6 +289,7 @@ def _load_workspace_groups(self, prompts, workspace, all_workspaces_groups):
188
289
f"it will be created at the account with name : { workspace .workspace_name } _{ group_name } "
189
290
):
190
291
all_workspaces_groups [f"{ workspace .workspace_name } _{ group_name } " ] = full_workspace_group
292
+ context .renamed_groups [group_name ] = f"{ workspace .workspace_name } _{ group_name } "
191
293
continue
192
294
logger .info (f"Found new group { group_name } " )
193
295
all_workspaces_groups [group_name ] = full_workspace_group
@@ -212,7 +314,7 @@ def _has_same_members(group_1: Group, group_2: Group) -> bool:
212
314
ws_members_set_2 = set ([m .display for m in group_2 .members ] if group_2 .members else [])
213
315
return not bool ((ws_members_set_1 - ws_members_set_2 ).union (ws_members_set_2 - ws_members_set_1 ))
214
316
215
- def _get_account_groups (self ) -> dict [str | None , list [ ComplexValue ] | None ]:
317
+ def _get_account_groups (self ) -> dict [str , AccountGroupDetails ]:
216
318
logger .debug ("Listing groups in account" )
217
319
acc_groups = {}
218
320
for acc_grp_id in self ._ac .groups .list (attributes = "id" ):
@@ -222,7 +324,10 @@ def _get_account_groups(self) -> dict[str | None, list[ComplexValue] | None]:
222
324
if not full_account_group :
223
325
continue
224
326
logger .debug (f"Found account group { full_account_group .display_name } " )
225
- acc_groups [full_account_group .display_name ] = full_account_group .members
327
+ assert full_account_group .display_name is not None , "group name undefined"
328
+ acc_groups [full_account_group .display_name ] = AccountGroupDetails (
329
+ id = acc_grp_id .id , members = full_account_group .members
330
+ )
226
331
227
332
logger .info (f"{ len (acc_groups )} account groups found" )
228
333
return acc_groups
0 commit comments