Skip to content

Commit b87816c

Browse files
authored
test: add unit tests to execute transform flow e2e (#737)
* test: add unit tests to execute transform flow e2e * ops: update pip installation command * refactor: apply for_each iteration in the transform flow * test: update to correctly process child rows and return expected structure
1 parent d7e480d commit b87816c

File tree

3 files changed

+105
-2
lines changed

3 files changed

+105
-2
lines changed

.github/workflows/_test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ jobs:
4343
- name: Install Python toolchains
4444
run: |
4545
source .venv/bin/activate
46-
pip install maturin pytest mypy
46+
pip install maturin mypy pytest pytest-asyncio
4747
- name: Python build
4848
run: |
4949
source .venv/bin/activate

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ module-name = "cocoindex._engine"
2929
features = ["pyo3/extension-module"]
3030

3131
[project.optional-dependencies]
32-
dev = ["pytest", "ruff", "mypy", "pre-commit"]
32+
dev = ["pytest", "pytest-asyncio", "ruff", "mypy", "pre-commit"]
3333

3434
embeddings = ["sentence-transformers>=3.3.1"]
3535

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
import typing
2+
from dataclasses import dataclass
3+
from typing import Any
4+
5+
import pytest
6+
7+
import cocoindex
8+
9+
10+
@dataclass
11+
class Child:
12+
value: int
13+
14+
15+
@dataclass
16+
class Parent:
17+
children: list[Child]
18+
19+
20+
# Fixture to initialize CocoIndex library
21+
@pytest.fixture(scope="session", autouse=True)
22+
def init_cocoindex() -> typing.Generator[None, None, None]:
23+
cocoindex.init()
24+
yield
25+
26+
27+
@cocoindex.op.function()
28+
def add_suffix(text: str) -> str:
29+
"""Append ' world' to the input text."""
30+
return f"{text} world"
31+
32+
33+
@cocoindex.transform_flow()
34+
def simple_transform(text: cocoindex.DataSlice[str]) -> cocoindex.DataSlice[str]:
35+
"""Transform flow that applies add_suffix to input text."""
36+
return text.transform(add_suffix)
37+
38+
39+
@cocoindex.op.function()
40+
def extract_value(value: int) -> int:
41+
"""Extracts the value."""
42+
return value
43+
44+
45+
@cocoindex.transform_flow()
46+
def for_each_transform(
47+
data: cocoindex.DataSlice[Parent],
48+
) -> cocoindex.DataSlice[Any]:
49+
"""Transform flow that processes child rows to extract values."""
50+
with data["children"].row() as child:
51+
child["new_field"] = child["value"].transform(extract_value)
52+
return data
53+
54+
55+
def test_simple_transform_flow() -> None:
56+
"""Test the simple transform flow."""
57+
input_text = "hello"
58+
result = simple_transform.eval(input_text)
59+
assert result == "hello world", f"Expected 'hello world', got {result}"
60+
61+
result = simple_transform.eval("")
62+
assert result == " world", f"Expected ' world', got {result}"
63+
64+
65+
@pytest.mark.asyncio
66+
async def test_simple_transform_flow_async() -> None:
67+
"""Test the simple transform flow asynchronously."""
68+
input_text = "async"
69+
result = await simple_transform.eval_async(input_text)
70+
assert result == "async world", f"Expected 'async world', got {result}"
71+
72+
73+
def test_for_each_transform_flow() -> None:
74+
"""Test the complex transform flow with child rows."""
75+
input_data = Parent(children=[Child(1), Child(2), Child(3)])
76+
result = for_each_transform.eval(input_data)
77+
expected = {
78+
"children": [
79+
{"value": 1, "new_field": 1},
80+
{"value": 2, "new_field": 2},
81+
{"value": 3, "new_field": 3},
82+
]
83+
}
84+
assert result == expected, f"Expected {expected}, got {result}"
85+
86+
input_data = Parent(children=[])
87+
result = for_each_transform.eval(input_data)
88+
assert result == {"children": []}, f"Expected {{'children': []}}, got {result}"
89+
90+
91+
@pytest.mark.asyncio
92+
async def test_for_each_transform_flow_async() -> None:
93+
"""Test the complex transform flow asynchronously."""
94+
input_data = Parent(children=[Child(4), Child(5)])
95+
result = await for_each_transform.eval_async(input_data)
96+
expected = {
97+
"children": [
98+
{"value": 4, "new_field": 4},
99+
{"value": 5, "new_field": 5},
100+
]
101+
}
102+
103+
assert result == expected, f"Expected {expected}, got {result}"

0 commit comments

Comments
 (0)