@@ -179,7 +179,15 @@ def retrieve_online_documents(
179
179
embedding : Optional [List [float ]],
180
180
top_k : int ,
181
181
distance_metric : Optional [str ] = "L2" ,
182
- ) -> List [Tuple [Optional [datetime ], Optional [EntityKeyProto ], Optional [ValueProto ], Optional [ValueProto ], Optional [ValueProto ]]]:
182
+ ) -> List [
183
+ Tuple [
184
+ Optional [datetime ],
185
+ Optional [EntityKeyProto ],
186
+ Optional [ValueProto ],
187
+ Optional [ValueProto ],
188
+ Optional [ValueProto ],
189
+ ]
190
+ ]:
183
191
assert isinstance (config .online_store , RemoteOnlineStoreConfig )
184
192
config .online_store .__class__ = RemoteOnlineStoreConfig
185
193
@@ -190,18 +198,16 @@ def retrieve_online_documents(
190
198
if response .status_code == 200 :
191
199
logger .debug ("Able to retrieve the online documents from feature server." )
192
200
response_json = json .loads (response .text )
193
- event_ts = self ._get_event_ts (response_json )
201
+ event_ts : Optional [ datetime ] = self ._get_event_ts (response_json )
194
202
195
203
# Create feature name to index mapping for efficient lookup
196
204
feature_name_to_index = {
197
- name : idx for idx , name in enumerate (response_json ["metadata" ]["feature_names" ])
205
+ name : idx
206
+ for idx , name in enumerate (response_json ["metadata" ]["feature_names" ])
198
207
}
199
208
200
209
vector_field_metadata = _get_feature_view_vector_field_metadata (table )
201
210
202
- # Extract feature names once
203
- feature_names = response_json ["metadata" ]["feature_names" ]
204
-
205
211
# Process each result row
206
212
num_results = len (response_json ["results" ][0 ]["values" ])
207
213
result_tuples = []
@@ -215,13 +221,21 @@ def retrieve_online_documents(
215
221
response_json , feature_name_to_index , vector_field_metadata , row_idx
216
222
)
217
223
distance_val = self ._extract_distance_value (
218
- response_json , feature_name_to_index , ' distance' , row_idx
224
+ response_json , feature_name_to_index , " distance" , row_idx
219
225
)
220
226
entity_key_proto = self ._construct_entity_key_from_response (
221
- response_json , row_idx , feature_name_to_index
227
+ response_json , row_idx , feature_name_to_index , table
222
228
)
223
229
224
- result_tuples .append ((event_ts , entity_key_proto , feature_val , vector_value , distance_val ))
230
+ result_tuples .append (
231
+ (
232
+ event_ts ,
233
+ entity_key_proto ,
234
+ feature_val ,
235
+ vector_value ,
236
+ distance_val ,
237
+ )
238
+ )
225
239
226
240
return result_tuples
227
241
else :
@@ -238,50 +252,77 @@ def retrieve_online_documents_v2(
238
252
top_k : int ,
239
253
distance_metric : Optional [str ] = None ,
240
254
query_string : Optional [str ] = None ,
241
- ) -> List [Tuple [Optional [datetime ], Optional [EntityKeyProto ], Optional [Dict [str , ValueProto ]]]]:
255
+ ) -> List [
256
+ Tuple [
257
+ Optional [datetime ],
258
+ Optional [EntityKeyProto ],
259
+ Optional [Dict [str , ValueProto ]],
260
+ ]
261
+ ]:
242
262
assert isinstance (config .online_store , RemoteOnlineStoreConfig )
243
263
config .online_store .__class__ = RemoteOnlineStoreConfig
244
264
245
265
req_body = self ._construct_online_documents_v2_api_json_request (
246
- table , requested_features , embedding , top_k , distance_metric , query_string , api_version = 2
266
+ table ,
267
+ requested_features ,
268
+ embedding ,
269
+ top_k ,
270
+ distance_metric ,
271
+ query_string ,
272
+ api_version = 2 ,
247
273
)
248
274
response = get_remote_online_documents (config = config , req_body = req_body )
249
275
if response .status_code == 200 :
250
276
logger .debug ("Able to retrieve the online documents from feature server." )
251
277
response_json = json .loads (response .text )
252
- event_ts = self ._get_event_ts (response_json )
253
-
278
+ event_ts : Optional [ datetime ] = self ._get_event_ts (response_json )
279
+
254
280
# Create feature name to index mapping for efficient lookup
255
281
feature_name_to_index = {
256
- name : idx for idx , name in enumerate (response_json ["metadata" ]["feature_names" ])
282
+ name : idx
283
+ for idx , name in enumerate (response_json ["metadata" ]["feature_names" ])
257
284
}
258
285
259
286
# Process each result row
260
- num_results = len (response_json ["results" ][0 ]["values" ]) if response_json ["results" ] else 0
287
+ num_results = (
288
+ len (response_json ["results" ][0 ]["values" ])
289
+ if response_json ["results" ]
290
+ else 0
291
+ )
261
292
result_tuples = []
262
293
263
294
for row_idx in range (num_results ):
264
295
# Build feature values dictionary for requested features
265
- feature_values_dict : Dict [ str , ValueProto ] = {}
266
-
296
+ feature_values_dict = {}
297
+
267
298
if requested_features :
268
299
for feature_name in requested_features :
269
300
if feature_name in feature_name_to_index :
270
301
feature_idx = feature_name_to_index [feature_name ]
271
- if self ._is_feature_present (response_json , feature_idx , row_idx ):
272
- feature_values_dict [feature_name ] = self ._extract_feature_value (
273
- response_json , feature_idx , row_idx
302
+ if self ._is_feature_present (
303
+ response_json , feature_idx , row_idx
304
+ ):
305
+ feature_values_dict [feature_name ] = (
306
+ self ._extract_feature_value (
307
+ response_json , feature_idx , row_idx
308
+ )
274
309
)
275
310
else :
276
311
feature_values_dict [feature_name ] = ValueProto ()
277
312
278
313
# Construct entity key proto using existing helper method
279
314
entity_key_proto = self ._construct_entity_key_from_response (
280
- response_json , row_idx , feature_name_to_index
315
+ response_json , row_idx , feature_name_to_index , table
316
+ )
317
+
318
+ result_tuples .append (
319
+ (
320
+ event_ts ,
321
+ entity_key_proto ,
322
+ feature_values_dict if feature_values_dict else None ,
323
+ )
281
324
)
282
325
283
- result_tuples .append ((event_ts , entity_key_proto , feature_values_dict ))
284
-
285
326
return result_tuples
286
327
else :
287
328
error_msg = f"Unable to retrieve the online documents using feature server API. Error_code={ response .status_code } , error_message={ response .text } "
@@ -293,8 +334,8 @@ def _extract_requested_feature_value(
293
334
response_json : dict ,
294
335
feature_name_to_index : dict ,
295
336
requested_features : Optional [List [str ]],
296
- row_idx : int
297
- ) -> ValueProto :
337
+ row_idx : int ,
338
+ ) -> Optional [ ValueProto ] :
298
339
"""Extract the first available requested feature value."""
299
340
if not requested_features :
300
341
return ValueProto ()
@@ -303,7 +344,9 @@ def _extract_requested_feature_value(
303
344
if feature_name in feature_name_to_index :
304
345
feature_idx = feature_name_to_index [feature_name ]
305
346
if self ._is_feature_present (response_json , feature_idx , row_idx ):
306
- return self ._extract_feature_value (response_json , feature_idx , row_idx )
347
+ return self ._extract_feature_value (
348
+ response_json , feature_idx , row_idx
349
+ )
307
350
308
351
return ValueProto ()
309
352
def _extract_vector_field_value(
    self,
    response_json: dict,
    feature_name_to_index: dict,
    vector_field_metadata,
    row_idx: int,
) -> Optional[ValueProto]:
    """Return the vector field's value for one result row.

    Args:
        response_json: Decoded response body; per-feature data is read from
            ``response_json["results"]`` by the helpers called here.
        feature_name_to_index: Maps a feature name to its position in
            ``response_json["results"]``.
        vector_field_metadata: Object exposing a ``name`` attribute that
            identifies the vector field; may be falsy when there is none.
        row_idx: Index of the row to read.

    Returns:
        The converted value when the field is known and marked present for
        the row; otherwise an empty ``ValueProto``.
    """
    # No vector field configured: nothing to look up.
    if not vector_field_metadata:
        return ValueProto()

    field_name = vector_field_metadata.name
    # The response does not carry this field at all.
    if field_name not in feature_name_to_index:
        return ValueProto()

    column = feature_name_to_index[field_name]
    if not self._is_feature_present(response_json, column, row_idx):
        return ValueProto()

    return self._extract_feature_value(response_json, column, row_idx)
326
374
def _extract_distance_value(
    self,
    response_json: dict,
    feature_name_to_index: dict,
    distance_feature_name: str,
    row_idx: int,
) -> Optional[ValueProto]:
    """Extract distance/score value from response.

    Args:
        response_json: Decoded response body. ``response_json["results"]`` is
            indexed by feature position; each entry holds ``"values"`` and
            ``"statuses"`` lists indexed by row.
        feature_name_to_index: Maps a feature name to its position in
            ``response_json["results"]``.
        distance_feature_name: Name of the distance column to read.
        row_idx: Index of the row to read.

    Returns:
        A ``ValueProto`` whose ``float_val`` holds the distance, or an empty
        ``ValueProto`` when no name is given, the response has no such
        column, or the row's status is not ``"PRESENT"``.
    """
    if not distance_feature_name:
        return ValueProto()

    # A response without the distance column must not raise KeyError here;
    # treat it the same as "no distance available" for this row.
    distance_feature_idx = feature_name_to_index.get(distance_feature_name)
    if distance_feature_idx is None:
        return ValueProto()

    if not self._is_feature_present(response_json, distance_feature_idx, row_idx):
        return ValueProto()

    distance_value = response_json["results"][distance_feature_idx]["values"][
        row_idx
    ]
    distance_val = ValueProto()
    distance_val.float_val = float(distance_value)
    return distance_val
346
396
347
- def _is_feature_present (self , response_json : dict , feature_idx : int , row_idx : int ) -> bool :
397
+ def _is_feature_present (
398
+ self , response_json : dict , feature_idx : int , row_idx : int
399
+ ) -> bool :
348
400
"""Check if a feature is present in the response."""
349
401
return response_json ["results" ][feature_idx ]["statuses" ][row_idx ] == "PRESENT"
350
402
@@ -432,12 +484,19 @@ def _get_event_ts(self, response_json) -> datetime:
432
484
return datetime .fromisoformat (event_ts .replace ("Z" , "+00:00" ))
433
485
434
486
def _construct_entity_key_from_response (
435
- self , response_json : dict , row_idx : int , feature_name_to_index : dict
487
+ self ,
488
+ response_json : dict ,
489
+ row_idx : int ,
490
+ feature_name_to_index : dict ,
491
+ table : FeatureView ,
436
492
) -> Optional [EntityKeyProto ]:
437
493
"""Construct EntityKeyProto from response data."""
438
- # Look for entity key fields in the response
439
- entity_fields = [name for name in feature_name_to_index .keys ()
440
- if name .endswith ('_id' ) or name in ['id' , 'key' , 'entity_id' ]]
494
+ # Use the feature view's join_keys to identify entity fields
495
+ entity_fields = [
496
+ join_key
497
+ for join_key in table .join_keys
498
+ if join_key in feature_name_to_index
499
+ ]
441
500
442
501
if not entity_fields :
443
502
return None
@@ -449,12 +508,16 @@ def _construct_entity_key_from_response(
449
508
if entity_field in feature_name_to_index :
450
509
feature_idx = feature_name_to_index [entity_field ]
451
510
if self ._is_feature_present (response_json , feature_idx , row_idx ):
452
- entity_value = self ._extract_feature_value (response_json , feature_idx , row_idx )
511
+ entity_value = self ._extract_feature_value (
512
+ response_json , feature_idx , row_idx
513
+ )
453
514
entity_key_proto .entity_values .append (entity_value )
454
515
455
516
return entity_key_proto if entity_key_proto .entity_values else None
456
517
457
- def _extract_feature_value (self , response_json : dict , feature_idx : int , row_idx : int ) -> ValueProto :
518
+ def _extract_feature_value (
519
+ self , response_json : dict , feature_idx : int , row_idx : int
520
+ ) -> ValueProto :
458
521
"""Extract and convert a feature value to ValueProto."""
459
522
raw_value = response_json ["results" ][feature_idx ]["values" ][row_idx ]
460
523
if raw_value is None :
0 commit comments