Skip to content

Commit dcc6339

Browse files
committed
Unit tests for Remote docuemnts retrival
Signed-off-by: jyejare <jyejare@redhat.com>
1 parent aff78a6 commit dcc6339

File tree

1 file changed

+366
-0
lines changed

1 file changed

+366
-0
lines changed
Lines changed: 366 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,366 @@
1+
import json
2+
import pytest
3+
from datetime import datetime, timedelta
4+
from unittest.mock import Mock, patch, MagicMock
5+
from typing import List, Optional
6+
7+
from feast import Entity, FeatureView, Field, FileSource, RepoConfig
8+
from feast.infra.online_stores.remote import RemoteOnlineStore, RemoteOnlineStoreConfig
9+
from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto
10+
from feast.protos.feast.types.Value_pb2 import Value as ValueProto
11+
from feast.types import Float32, String, Int64
12+
from feast.value_type import ValueType
13+
14+
15+
class TestRemoteOnlineStoreRetrieveDocuments:
16+
"""Test suite for retrieve_online_documents and retrieve_online_documents_v2 methods."""
17+
18+
@pytest.fixture
19+
def remote_store(self):
20+
"""Create a RemoteOnlineStore instance for testing."""
21+
return RemoteOnlineStore()
22+
23+
@pytest.fixture
24+
def config(self):
25+
"""Create a RepoConfig with RemoteOnlineStoreConfig."""
26+
return RepoConfig(
27+
project="test_project",
28+
online_store=RemoteOnlineStoreConfig(
29+
type="remote",
30+
path="http://localhost:6566"
31+
),
32+
registry="dummy_registry"
33+
)
34+
35+
@pytest.fixture
36+
def config_with_cert(self):
37+
"""Create a RepoConfig with RemoteOnlineStoreConfig including TLS cert."""
38+
return RepoConfig(
39+
project="test_project",
40+
online_store=RemoteOnlineStoreConfig(
41+
type="remote",
42+
path="http://localhost:6566",
43+
cert="/path/to/cert.pem"
44+
),
45+
registry="dummy_registry"
46+
)
47+
48+
@pytest.fixture
49+
def feature_view(self):
50+
"""Create a test FeatureView."""
51+
entity = Entity(name="user_id", description="User ID", value_type=ValueType.INT64)
52+
source = FileSource(
53+
path="test.parquet",
54+
timestamp_field="event_timestamp"
55+
)
56+
return FeatureView(
57+
name="test_feature_view",
58+
entities=[entity],
59+
ttl=timedelta(days=1),
60+
schema=[
61+
Field(name="user_id", dtype=Int64), # Entity field
62+
Field(name="feature1", dtype=String),
63+
Field(name="embedding", dtype=Float32),
64+
],
65+
source=source,
66+
)
67+
68+
@pytest.fixture
69+
def mock_successful_response(self):
70+
"""Create a mock successful HTTP response for documents retrieval."""
71+
return {
72+
"metadata": {
73+
"feature_names": ["feature1", "embedding", "distance", "user_id"]
74+
},
75+
"results": [
76+
{
77+
"values": ["test_value_1", "test_value_2"],
78+
"statuses": ["PRESENT", "PRESENT"]
79+
}, # feature1
80+
{
81+
"values": [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
82+
"statuses": ["PRESENT", "PRESENT"],
83+
"event_timestamps": ["2023-01-01T00:00:00Z", "2023-01-01T01:00:00Z"]
84+
}, # embedding
85+
{
86+
"values": [0.85, 0.92],
87+
"statuses": ["PRESENT", "PRESENT"]
88+
}, # distance
89+
{
90+
"values": [123, 456],
91+
"statuses": ["PRESENT", "PRESENT"]
92+
} # user_id
93+
]
94+
}
95+
96+
@pytest.fixture
97+
def mock_successful_response_v2(self):
98+
"""Create a mock successful HTTP response for documents retrieval v2."""
99+
return {
100+
"metadata": {
101+
"feature_names": ["user_id", "feature1"]
102+
},
103+
"results": [
104+
{
105+
"values": [123, 456],
106+
"statuses": ["PRESENT", "PRESENT"]
107+
}, # user_id
108+
{
109+
"values": ["test_value_1", "test_value_2"],
110+
"statuses": ["PRESENT", "PRESENT"],
111+
"event_timestamps": ["2023-01-01T00:00:00Z", "2023-01-01T01:00:00Z"]
112+
} # feature1
113+
]
114+
}
115+
116+
@patch('feast.infra.online_stores.remote.get_remote_online_documents')
117+
def test_retrieve_online_documents_success(
118+
self,
119+
mock_get_remote_online_documents,
120+
remote_store,
121+
config,
122+
feature_view,
123+
mock_successful_response
124+
):
125+
"""Test successful retrieve_online_documents call."""
126+
# Setup mock response
127+
mock_response = Mock()
128+
mock_response.status_code = 200
129+
mock_response.text = json.dumps(mock_successful_response)
130+
mock_get_remote_online_documents.return_value = mock_response
131+
132+
# Call the method
133+
result = remote_store.retrieve_online_documents(
134+
config=config,
135+
table=feature_view,
136+
requested_features=["feature1"],
137+
embedding=[0.1, 0.2, 0.3],
138+
top_k=2,
139+
distance_metric="L2"
140+
)
141+
142+
# Verify the call was made correctly
143+
mock_get_remote_online_documents.assert_called_once()
144+
call_args = mock_get_remote_online_documents.call_args
145+
assert call_args[1]['config'] == config
146+
147+
# Parse the request body to verify it's correct
148+
req_body = json.loads(call_args[1]['req_body'])
149+
assert req_body['features'] == ['test_feature_view:feature1']
150+
assert req_body['query'] == [0.1, 0.2, 0.3]
151+
assert req_body['top_k'] == 2
152+
assert req_body['distance_metric'] == "L2"
153+
154+
# Verify the result
155+
assert len(result) == 2
156+
event_ts, entity_key_proto, feature_val, vector_value, distance_val = result[0]
157+
158+
# Check event timestamp
159+
assert isinstance(event_ts, datetime)
160+
161+
# Check that we got ValueProto objects
162+
assert isinstance(feature_val, ValueProto)
163+
assert isinstance(vector_value, ValueProto)
164+
assert isinstance(distance_val, ValueProto)
165+
166+
@patch('feast.infra.online_stores.remote.get_remote_online_documents')
167+
def test_retrieve_online_documents_v2_success(
168+
self,
169+
mock_get_remote_online_documents,
170+
remote_store,
171+
config,
172+
feature_view,
173+
mock_successful_response_v2
174+
):
175+
"""Test successful retrieve_online_documents_v2 call."""
176+
# Setup mock response
177+
mock_response = Mock()
178+
mock_response.status_code = 200
179+
mock_response.text = json.dumps(mock_successful_response_v2)
180+
mock_get_remote_online_documents.return_value = mock_response
181+
182+
# Call the method
183+
result = remote_store.retrieve_online_documents_v2(
184+
config=config,
185+
table=feature_view,
186+
requested_features=["feature1"],
187+
embedding=[0.1, 0.2, 0.3],
188+
top_k=2,
189+
distance_metric="cosine",
190+
query_string="test query"
191+
)
192+
193+
# Verify the call was made correctly
194+
mock_get_remote_online_documents.assert_called_once()
195+
call_args = mock_get_remote_online_documents.call_args
196+
assert call_args[1]['config'] == config
197+
198+
# Parse the request body to verify it's correct
199+
req_body = json.loads(call_args[1]['req_body'])
200+
assert req_body['features'] == ['test_feature_view:feature1']
201+
assert req_body['query'] == [0.1, 0.2, 0.3]
202+
assert req_body['top_k'] == 2
203+
assert req_body['distance_metric'] == "cosine"
204+
assert req_body['query_string'] == "test query"
205+
assert req_body['api_version'] == 2
206+
207+
# Verify the result
208+
assert len(result) == 2
209+
event_ts, entity_key_proto, feature_values_dict = result[0]
210+
211+
# Check event timestamp
212+
assert isinstance(event_ts, datetime)
213+
214+
# Check entity key proto
215+
assert isinstance(entity_key_proto, EntityKeyProto)
216+
217+
# Check feature values dictionary
218+
assert isinstance(feature_values_dict, dict)
219+
assert "feature1" in feature_values_dict
220+
assert isinstance(feature_values_dict["feature1"], ValueProto)
221+
222+
@patch('feast.infra.online_stores.remote.get_remote_online_documents')
223+
def test_retrieve_online_documents_with_cert(
224+
self,
225+
mock_get_remote_online_documents,
226+
remote_store,
227+
config_with_cert,
228+
feature_view,
229+
mock_successful_response
230+
):
231+
"""Test retrieve_online_documents with TLS certificate."""
232+
# Setup mock response
233+
mock_response = Mock()
234+
mock_response.status_code = 200
235+
mock_response.text = json.dumps(mock_successful_response)
236+
mock_get_remote_online_documents.return_value = mock_response
237+
238+
# Call the method
239+
result = remote_store.retrieve_online_documents(
240+
config=config_with_cert,
241+
table=feature_view,
242+
requested_features=["feature1"],
243+
embedding=[0.1, 0.2, 0.3],
244+
top_k=1
245+
)
246+
247+
# Verify the call was made
248+
mock_get_remote_online_documents.assert_called_once()
249+
assert len(result) == 2
250+
251+
@patch('feast.infra.online_stores.remote.get_remote_online_documents')
252+
def test_retrieve_online_documents_error_response(
253+
self,
254+
mock_get_remote_online_documents,
255+
remote_store,
256+
config,
257+
feature_view
258+
):
259+
"""Test retrieve_online_documents with error response."""
260+
# Setup mock error response
261+
mock_response = Mock()
262+
mock_response.status_code = 500
263+
mock_response.text = "Internal Server Error"
264+
mock_get_remote_online_documents.return_value = mock_response
265+
266+
# Call the method and expect RuntimeError
267+
with pytest.raises(RuntimeError, match="Unable to retrieve the online documents using feature server API"):
268+
remote_store.retrieve_online_documents(
269+
config=config,
270+
table=feature_view,
271+
requested_features=["feature1"],
272+
embedding=[0.1, 0.2, 0.3],
273+
top_k=1
274+
)
275+
276+
@patch('feast.infra.online_stores.remote.get_remote_online_documents')
277+
def test_retrieve_online_documents_v2_error_response(
278+
self,
279+
mock_get_remote_online_documents,
280+
remote_store,
281+
config,
282+
feature_view
283+
):
284+
"""Test retrieve_online_documents_v2 with error response."""
285+
# Setup mock error response
286+
mock_response = Mock()
287+
mock_response.status_code = 404
288+
mock_response.text = "Not Found"
289+
mock_get_remote_online_documents.return_value = mock_response
290+
291+
# Call the method and expect RuntimeError
292+
with pytest.raises(RuntimeError, match="Unable to retrieve the online documents using feature server API"):
293+
remote_store.retrieve_online_documents_v2(
294+
config=config,
295+
table=feature_view,
296+
requested_features=["feature1"],
297+
embedding=[0.1, 0.2, 0.3],
298+
top_k=1
299+
)
300+
301+
def test_construct_online_documents_api_json_request(self, remote_store, feature_view):
302+
"""Test _construct_online_documents_api_json_request method."""
303+
result = remote_store._construct_online_documents_api_json_request(
304+
table=feature_view,
305+
requested_features=["feature1", "feature2"],
306+
embedding=[0.1, 0.2, 0.3],
307+
top_k=5,
308+
distance_metric="cosine"
309+
)
310+
311+
parsed_result = json.loads(result)
312+
assert parsed_result["features"] == ["test_feature_view:feature1", "test_feature_view:feature2"]
313+
assert parsed_result["query"] == [0.1, 0.2, 0.3]
314+
assert parsed_result["top_k"] == 5
315+
assert parsed_result["distance_metric"] == "cosine"
316+
317+
def test_construct_online_documents_v2_api_json_request(self, remote_store, feature_view):
318+
"""Test _construct_online_documents_v2_api_json_request method."""
319+
result = remote_store._construct_online_documents_v2_api_json_request(
320+
table=feature_view,
321+
requested_features=["feature1"],
322+
embedding=[0.1, 0.2],
323+
top_k=3,
324+
distance_metric="L2",
325+
query_string="test query",
326+
api_version=2
327+
)
328+
329+
parsed_result = json.loads(result)
330+
assert parsed_result["features"] == ["test_feature_view:feature1"]
331+
assert parsed_result["query"] == [0.1, 0.2]
332+
assert parsed_result["top_k"] == 3
333+
assert parsed_result["distance_metric"] == "L2"
334+
assert parsed_result["query_string"] == "test query"
335+
assert parsed_result["api_version"] == 2
336+
337+
338+
def test_extract_requested_feature_value(self, remote_store):
339+
"""Test _extract_requested_feature_value helper method."""
340+
response_json = {
341+
"results": [
342+
{
343+
"values": ["test_value"],
344+
"statuses": ["PRESENT"]
345+
}
346+
]
347+
}
348+
feature_name_to_index = {"feature1": 0}
349+
350+
result = remote_store._extract_requested_feature_value(
351+
response_json, feature_name_to_index, ["feature1"], 0
352+
)
353+
assert isinstance(result, ValueProto)
354+
355+
def test_is_feature_present(self, remote_store):
356+
"""Test _is_feature_present helper method."""
357+
response_json = {
358+
"results": [
359+
{
360+
"statuses": ["PRESENT", "NOT_FOUND"]
361+
}
362+
]
363+
}
364+
365+
assert remote_store._is_feature_present(response_json, 0, 0) == True
366+
assert remote_store._is_feature_present(response_json, 0, 1) == False

0 commit comments

Comments
 (0)