Skip to content

Commit 48f27f7

Browse files
committed
Fixed checks
Signed-off-by: jyejare <jyejare@redhat.com>
1 parent dcc6339 commit 48f27f7

File tree

4 files changed

+212
-167
lines changed

4 files changed

+212
-167
lines changed

sdk/python/feast/feature_server.py

Lines changed: 2 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -265,13 +265,9 @@ async def retrieve_online_documents(
265265
# Initialize parameters for FeatureStore.retrieve_online_documents_v2(...) call
266266
features = await run_in_threadpool(_get_features, request, store)
267267

268-
read_params = dict(
269-
features=features,
270-
query=request.query,
271-
top_k=request.top_k
272-
)
268+
read_params = dict(features=features, query=request.query, top_k=request.top_k)
273269
if request.api_version == 2 and request.query_string is not None:
274-
read_params['query_string'] = request.query_string
270+
read_params["query_string"] = request.query_string
275271

276272
if request.api_version == 2:
277273
response = await run_in_threadpool(

sdk/python/feast/feature_store.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -2102,7 +2102,7 @@ def retrieve_online_documents_v2(
21022102
distance_metric,
21032103
query_string,
21042104
)
2105-
2105+
21062106
def _retrieve_from_online_store(
21072107
self,
21082108
provider: Provider,
@@ -2245,7 +2245,6 @@ def _retrieve_from_online_store_v2(
22452245
data=entity_key_dict,
22462246
)
22472247

2248-
22492248
return OnlineResponse(online_features_response)
22502249

22512250
def serve(

sdk/python/feast/infra/online_stores/remote.py

Lines changed: 103 additions & 40 deletions
Original file line number | Diff line number | Diff line change
@@ -179,7 +179,15 @@ def retrieve_online_documents(
179179
embedding: Optional[List[float]],
180180
top_k: int,
181181
distance_metric: Optional[str] = "L2",
182-
) -> List[Tuple[Optional[datetime], Optional[EntityKeyProto], Optional[ValueProto], Optional[ValueProto], Optional[ValueProto]]]:
182+
) -> List[
183+
Tuple[
184+
Optional[datetime],
185+
Optional[EntityKeyProto],
186+
Optional[ValueProto],
187+
Optional[ValueProto],
188+
Optional[ValueProto],
189+
]
190+
]:
183191
assert isinstance(config.online_store, RemoteOnlineStoreConfig)
184192
config.online_store.__class__ = RemoteOnlineStoreConfig
185193

@@ -190,18 +198,16 @@ def retrieve_online_documents(
190198
if response.status_code == 200:
191199
logger.debug("Able to retrieve the online documents from feature server.")
192200
response_json = json.loads(response.text)
193-
event_ts = self._get_event_ts(response_json)
201+
event_ts: Optional[datetime] = self._get_event_ts(response_json)
194202

195203
# Create feature name to index mapping for efficient lookup
196204
feature_name_to_index = {
197-
name: idx for idx, name in enumerate(response_json["metadata"]["feature_names"])
205+
name: idx
206+
for idx, name in enumerate(response_json["metadata"]["feature_names"])
198207
}
199208

200209
vector_field_metadata = _get_feature_view_vector_field_metadata(table)
201210

202-
# Extract feature names once
203-
feature_names = response_json["metadata"]["feature_names"]
204-
205211
# Process each result row
206212
num_results = len(response_json["results"][0]["values"])
207213
result_tuples = []
@@ -215,13 +221,21 @@ def retrieve_online_documents(
215221
response_json, feature_name_to_index, vector_field_metadata, row_idx
216222
)
217223
distance_val = self._extract_distance_value(
218-
response_json, feature_name_to_index, 'distance', row_idx
224+
response_json, feature_name_to_index, "distance", row_idx
219225
)
220226
entity_key_proto = self._construct_entity_key_from_response(
221-
response_json, row_idx, feature_name_to_index
227+
response_json, row_idx, feature_name_to_index, table
222228
)
223229

224-
result_tuples.append((event_ts, entity_key_proto, feature_val, vector_value, distance_val))
230+
result_tuples.append(
231+
(
232+
event_ts,
233+
entity_key_proto,
234+
feature_val,
235+
vector_value,
236+
distance_val,
237+
)
238+
)
225239

226240
return result_tuples
227241
else:
@@ -238,50 +252,77 @@ def retrieve_online_documents_v2(
238252
top_k: int,
239253
distance_metric: Optional[str] = None,
240254
query_string: Optional[str] = None,
241-
) -> List[Tuple[Optional[datetime], Optional[EntityKeyProto], Optional[Dict[str, ValueProto]]]]:
255+
) -> List[
256+
Tuple[
257+
Optional[datetime],
258+
Optional[EntityKeyProto],
259+
Optional[Dict[str, ValueProto]],
260+
]
261+
]:
242262
assert isinstance(config.online_store, RemoteOnlineStoreConfig)
243263
config.online_store.__class__ = RemoteOnlineStoreConfig
244264

245265
req_body = self._construct_online_documents_v2_api_json_request(
246-
table, requested_features, embedding, top_k, distance_metric, query_string, api_version=2
266+
table,
267+
requested_features,
268+
embedding,
269+
top_k,
270+
distance_metric,
271+
query_string,
272+
api_version=2,
247273
)
248274
response = get_remote_online_documents(config=config, req_body=req_body)
249275
if response.status_code == 200:
250276
logger.debug("Able to retrieve the online documents from feature server.")
251277
response_json = json.loads(response.text)
252-
event_ts = self._get_event_ts(response_json)
253-
278+
event_ts: Optional[datetime] = self._get_event_ts(response_json)
279+
254280
# Create feature name to index mapping for efficient lookup
255281
feature_name_to_index = {
256-
name: idx for idx, name in enumerate(response_json["metadata"]["feature_names"])
282+
name: idx
283+
for idx, name in enumerate(response_json["metadata"]["feature_names"])
257284
}
258285

259286
# Process each result row
260-
num_results = len(response_json["results"][0]["values"]) if response_json["results"] else 0
287+
num_results = (
288+
len(response_json["results"][0]["values"])
289+
if response_json["results"]
290+
else 0
291+
)
261292
result_tuples = []
262293

263294
for row_idx in range(num_results):
264295
# Build feature values dictionary for requested features
265-
feature_values_dict: Dict[str, ValueProto] = {}
266-
296+
feature_values_dict = {}
297+
267298
if requested_features:
268299
for feature_name in requested_features:
269300
if feature_name in feature_name_to_index:
270301
feature_idx = feature_name_to_index[feature_name]
271-
if self._is_feature_present(response_json, feature_idx, row_idx):
272-
feature_values_dict[feature_name] = self._extract_feature_value(
273-
response_json, feature_idx, row_idx
302+
if self._is_feature_present(
303+
response_json, feature_idx, row_idx
304+
):
305+
feature_values_dict[feature_name] = (
306+
self._extract_feature_value(
307+
response_json, feature_idx, row_idx
308+
)
274309
)
275310
else:
276311
feature_values_dict[feature_name] = ValueProto()
277312

278313
# Construct entity key proto using existing helper method
279314
entity_key_proto = self._construct_entity_key_from_response(
280-
response_json, row_idx, feature_name_to_index
315+
response_json, row_idx, feature_name_to_index, table
316+
)
317+
318+
result_tuples.append(
319+
(
320+
event_ts,
321+
entity_key_proto,
322+
feature_values_dict if feature_values_dict else None,
323+
)
281324
)
282325

283-
result_tuples.append((event_ts, entity_key_proto, feature_values_dict))
284-
285326
return result_tuples
286327
else:
287328
error_msg = f"Unable to retrieve the online documents using feature server API. Error_code={response.status_code}, error_message={response.text}"
@@ -293,8 +334,8 @@ def _extract_requested_feature_value(
293334
response_json: dict,
294335
feature_name_to_index: dict,
295336
requested_features: Optional[List[str]],
296-
row_idx: int
297-
) -> ValueProto:
337+
row_idx: int,
338+
) -> Optional[ValueProto]:
298339
"""Extract the first available requested feature value."""
299340
if not requested_features:
300341
return ValueProto()
@@ -303,7 +344,9 @@ def _extract_requested_feature_value(
303344
if feature_name in feature_name_to_index:
304345
feature_idx = feature_name_to_index[feature_name]
305346
if self._is_feature_present(response_json, feature_idx, row_idx):
306-
return self._extract_feature_value(response_json, feature_idx, row_idx)
347+
return self._extract_feature_value(
348+
response_json, feature_idx, row_idx
349+
)
307350

308351
return ValueProto()
309352

@@ -312,15 +355,20 @@ def _extract_vector_field_value(
312355
response_json: dict,
313356
feature_name_to_index: dict,
314357
vector_field_metadata,
315-
row_idx: int
316-
) -> ValueProto:
358+
row_idx: int,
359+
) -> Optional[ValueProto]:
317360
"""Extract vector field value from response."""
318-
if not vector_field_metadata or vector_field_metadata.name not in feature_name_to_index:
361+
if (
362+
not vector_field_metadata
363+
or vector_field_metadata.name not in feature_name_to_index
364+
):
319365
return ValueProto()
320366

321367
vector_feature_idx = feature_name_to_index[vector_field_metadata.name]
322368
if self._is_feature_present(response_json, vector_feature_idx, row_idx):
323-
return self._extract_feature_value(response_json, vector_feature_idx, row_idx)
369+
return self._extract_feature_value(
370+
response_json, vector_feature_idx, row_idx
371+
)
324372

325373
return ValueProto()
326374

@@ -329,22 +377,26 @@ def _extract_distance_value(
329377
response_json: dict,
330378
feature_name_to_index: dict,
331379
distance_feature_name: str,
332-
row_idx: int
333-
) -> ValueProto:
380+
row_idx: int,
381+
) -> Optional[ValueProto]:
334382
"""Extract distance/score value from response."""
335383
if not distance_feature_name:
336384
return ValueProto()
337385

338386
distance_feature_idx = feature_name_to_index[distance_feature_name]
339387
if self._is_feature_present(response_json, distance_feature_idx, row_idx):
340-
distance_value = response_json["results"][distance_feature_idx]["values"][row_idx]
388+
distance_value = response_json["results"][distance_feature_idx]["values"][
389+
row_idx
390+
]
341391
distance_val = ValueProto()
342392
distance_val.float_val = float(distance_value)
343393
return distance_val
344394

345395
return ValueProto()
346396

347-
def _is_feature_present(self, response_json: dict, feature_idx: int, row_idx: int) -> bool:
397+
def _is_feature_present(
398+
self, response_json: dict, feature_idx: int, row_idx: int
399+
) -> bool:
348400
"""Check if a feature is present in the response."""
349401
return response_json["results"][feature_idx]["statuses"][row_idx] == "PRESENT"
350402

@@ -432,12 +484,19 @@ def _get_event_ts(self, response_json) -> datetime:
432484
return datetime.fromisoformat(event_ts.replace("Z", "+00:00"))
433485

434486
def _construct_entity_key_from_response(
435-
self, response_json: dict, row_idx: int, feature_name_to_index: dict
487+
self,
488+
response_json: dict,
489+
row_idx: int,
490+
feature_name_to_index: dict,
491+
table: FeatureView,
436492
) -> Optional[EntityKeyProto]:
437493
"""Construct EntityKeyProto from response data."""
438-
# Look for entity key fields in the response
439-
entity_fields = [name for name in feature_name_to_index.keys()
440-
if name.endswith('_id') or name in ['id', 'key', 'entity_id']]
494+
# Use the feature view's join_keys to identify entity fields
495+
entity_fields = [
496+
join_key
497+
for join_key in table.join_keys
498+
if join_key in feature_name_to_index
499+
]
441500

442501
if not entity_fields:
443502
return None
@@ -449,12 +508,16 @@ def _construct_entity_key_from_response(
449508
if entity_field in feature_name_to_index:
450509
feature_idx = feature_name_to_index[entity_field]
451510
if self._is_feature_present(response_json, feature_idx, row_idx):
452-
entity_value = self._extract_feature_value(response_json, feature_idx, row_idx)
511+
entity_value = self._extract_feature_value(
512+
response_json, feature_idx, row_idx
513+
)
453514
entity_key_proto.entity_values.append(entity_value)
454515

455516
return entity_key_proto if entity_key_proto.entity_values else None
456517

457-
def _extract_feature_value(self, response_json: dict, feature_idx: int, row_idx: int) -> ValueProto:
518+
def _extract_feature_value(
519+
self, response_json: dict, feature_idx: int, row_idx: int
520+
) -> ValueProto:
458521
"""Extract and convert a feature value to ValueProto."""
459522
raw_value = response_json["results"][feature_idx]["values"][row_idx]
460523
if raw_value is None:

0 commit comments

Comments
 (0)