Skip to content

Commit 85b2bc4

Browse files
committed
Set up initial project structure
Also add bare-bones test which can be run like so: ``` uvx pytest tests/test_postprocess.py ``` as long as `uvx` is in path.
1 parent 8a3dab6 commit 85b2bc4

File tree

18 files changed

+1591
-0
lines changed

18 files changed

+1591
-0
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
**/tests/
1111
/build
1212
*.pyc
13+
**/__pycache__
14+
*.egg-info/
1315
.vagrant
1416
**/compile_commands.json
1517
.python-version

c2rust-postprocess/README.md

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
# LLM-based postprocessing of c2rust transpiler output
2+
3+
This is currently a prototype effort to gauge the extent to which LLMs can
4+
accelerate the types of translation and migration that help move C code to Rust.
5+
6+
# Prerequisites
7+
8+
- Python 3.12 or later
9+
- `uv` in path
10+
- A valid `GEMINI_API_KEY` set
11+
- A transpiled codebase with a correct `compile_commands.json`
12+
13+
# Running
14+
15+
- `c2rust-postprocess path/to/compile_commands.json`, or
16+
- `uv run postprocess path/to/compile_commands.json`
17+
18+
# Testing
19+
20+
## Test prerequisites
21+
22+
- `bear` and `c2rust` in path
23+
24+
```
25+
uv run pytest -v
26+
uv run pytest -v tests/test_utils.py # filter tests to run
27+
```
28+
29+
## Misc
30+
31+
- `uv run ruff check --fix .` to lint (use `uv run ruff format .` to format)
32+
33+
# TODOs
34+
35+
- testable prototype
36+
- [x] gemini api support
37+
+ using synchronous API, tabled async API for now
38+
- file-based caching of model responses
39+
+ storage format could be improved to make it easier to create
40+
golden input/output pairs for testing
41+
- pluggable support for getting definitions
42+
- verifying correctness of responses
43+
- filtering by file and function name
44+
- openai model support
45+
- anthropic model support
46+
- openrouter API support?
47+
- non-trivial: use async support to speed up postprocessing
48+
+ supported by the Gemini API; unconfirmed whether the other providers' APIs offer it
49+
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
#!/bin/sh
# Thin wrapper around the Python entry point.
# Fix: forward all command-line arguments ("$@" was missing, so e.g. the
# compile_commands.json path was silently dropped); exec avoids a lingering
# shell process.
# NOTE(review): the README documents the entry point as `postprocess`, not
# `postproc` — confirm which name the project actually declares.
exec uv run postproc "$@"
Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
"""
2+
c2rust-postprocess: Transfer comments from C functions to Rust functions using LLMs.
3+
"""
4+
5+
import logging
6+
from pathlib import Path
7+
from textwrap import dedent
8+
from typing import Any
9+
10+
from postprocess.cache import AbstractCache
11+
from postprocess.definitions import get_c_comments, get_function_span_pairs
12+
from postprocess.models import get_model_by_id
13+
from postprocess.utils import get_compile_commands, get_rust_files, read_chunk, remove_backticks
14+
15+
from pygments import highlight
16+
from pygments.lexers import RustLexer
17+
from pygments.formatters import TerminalFormatter
18+
19+
# TODO: could also include
20+
# - validation function to check result
21+
# - list of comments to check for
22+
class CommentTransferPrompt:
    """A single comment-transfer request: C source, Rust source, instructions.

    Rendering the object with ``str()`` produces the full prompt text that is
    sent to the model.
    """

    __slots__ = ("c_function", "rust_function", "prompt_text")

    # Declared for type checkers; storage is provided by __slots__.
    c_function: str
    rust_function: str
    prompt_text: str

    def __init__(self, c_function: str, rust_function: str, prompt_text: str) -> None:
        self.c_function = c_function
        self.rust_function = rust_function
        self.prompt_text = prompt_text

    def __str__(self) -> str:
        # Instruction first, then both definitions in fenced code blocks.
        return (
            f"{self.prompt_text}\n\n"
            f"C function:\n```c\n{self.c_function}```\n\n"
            f"Rust function:\n```rust\n{self.rust_function}```\n"
        )
38+
39+
40+
def generate_prompts(
    compile_commands: list[dict[str, Any]], rust_file: Path
) -> list[CommentTransferPrompt]:
    """Build one comment-transfer prompt per matched Rust/C function pair.

    Args:
        compile_commands: Parsed compile_commands.json entries used to locate
            the original C translation units.
        rust_file: Transpiled Rust source file to pair against the C code.

    Returns:
        Prompts for every pair whose C definition contains comments; pairs
        without C comments are skipped since there is nothing to transfer.
    """
    pairs = get_function_span_pairs(compile_commands, rust_file)

    # TODO: make this function take a model and get the prompt from the model.
    # The instruction text is identical for every pair, so build it once here
    # instead of once per loop iteration.
    prompt_text = dedent(
        """
        Transfer the comments from the following C function to the corresponding Rust function.
        Do not add any comments that are not present in the C function.
        Respond with the Rust function definition with the transferred comments; say nothing else.
        """  # noqa: E501
    ).strip()

    prompts = []

    for rust_fn, c_fn in pairs:
        c_def = read_chunk(c_fn["file"], c_fn["start_byte"], c_fn["end_byte"])
        c_comments = get_c_comments(c_def)
        if not c_comments:
            logging.info(f"Skipping C function without comments: {c_fn['name']}")
            continue

        # Emitted only at DEBUG level (was a commented-out print with a TODO).
        logging.debug(f"C function {c_fn['name']} definition:\n{c_def}\n")

        rust_def = read_chunk(
            rust_fn["file"], rust_fn["start_byte"], rust_fn["end_byte"]
        )
        logging.debug(f"Rust function {rust_fn['name']} definition:\n{rust_def}\n")

        prompts.append(
            CommentTransferPrompt(
                c_function=c_def, rust_function=rust_def, prompt_text=prompt_text
            )
        )

    return prompts
78+
79+
80+
# TODO: get from model
# System prompt sent with every request; per the TODO above, a model
# abstraction should eventually supply this instead of a module constant.
SYSTEM_INSTRUCTION = (
    "You are a helpful assistant that transfers comments from C code to Rust code."
)
84+
85+
def transfer_comments(compile_commands_path: Path, cache: AbstractCache) -> None:
    """Transfer C comments into transpiled Rust sources using an LLM.

    Scans the directory containing *compile_commands_path* for Rust sources,
    builds one prompt per matched Rust/C function pair, and prints each model
    response (syntax-highlighted when stdout is a terminal). Responses are
    cached so repeated runs do not re-query the model.

    Args:
        compile_commands_path: Path to compile_commands.json; its parent
            directory is scanned for transpiled Rust files.
        cache: Cache consulted before, and updated after, each model call.
    """
    import sys  # local import: only needed for the isatty() check below

    # TODO: instantiate the model based on command line args
    # TODO: avoid google-specific import here
    from google.genai import types

    model = get_model_by_id(
        "gemini-3-pro-preview",
        generation_config={
            "system_instruction": types.Content(
                role="system",
                parts=[types.Part.from_text(text=SYSTEM_INSTRUCTION)],
            )
        },
    )

    rust_sources = get_rust_files(compile_commands_path.parent)
    compile_commands = get_compile_commands(compile_commands_path)

    for rust_file in rust_sources:
        for prompt in generate_prompts(compile_commands, rust_file):
            messages = [
                {"role": "user", "content": str(prompt)},
            ]

            # Cache hit skips the model call entirely.
            if not (response := cache.lookup(messages)):
                response = model.generate_with_tools(messages)
                if response is None:
                    logging.error("Model returned no response")
                    continue
                cache.update(messages, response)

            response = remove_backticks(response)

            # Fix: the check was hard-coded to `if True:` with no plain-text
            # fallback, so skipping the branch would have raised NameError.
            # Colorize only when stdout is an actual terminal.
            if sys.stdout.isatty():
                output = highlight(response, RustLexer(), TerminalFormatter())
            else:
                output = response

            print("Response:\n", output)
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
import argparse
2+
import logging
3+
import sys
4+
from collections.abc import Sequence
5+
6+
from postprocess import transfer_comments
7+
from postprocess.cache import DirectoryCache
8+
from postprocess.utils import existing_file
9+
10+
11+
def build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the postprocess CLI.

    Positional argument: path to compile_commands.json (validated to exist).
    Optional: --log-level, one of the standard logging level names.
    """
    parser = argparse.ArgumentParser(
        description="Transfer C function comments to Rust using LLMs.",
    )

    parser.add_argument(
        "compile_commands",
        type=existing_file,
        help="Path to compile_commands.json.",
    )
    # type=str and required=False are argparse defaults and were dropped.
    parser.add_argument(
        "--log-level",
        default="INFO",
        choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
        help="Logging level (default: INFO)",
    )

    return parser
31+
32+
33+
def main(argv: Sequence[str] | None = None) -> int:
    """CLI entry point: parse args, configure logging, run the transfer.

    Args:
        argv: Argument list for testing; None means use sys.argv.

    Returns:
        Process exit code (0 on success).
    """
    parser = build_arg_parser()
    args = parser.parse_args(argv)

    # Fix: basicConfig accepts level *names* directly; getLevelName() is
    # documented for level->name conversion, and the choices list already
    # guarantees an uppercase, valid name (so .upper() was redundant too).
    logging.basicConfig(level=args.log_level)

    cache = DirectoryCache()

    transfer_comments(args.compile_commands, cache)

    return 0
44+
45+
46+
if __name__ == "__main__":
    try:
        sys.exit(main())
    except KeyboardInterrupt:  # fix: the bound `e` was unused
        logging.warning("Interrupted by user, terminating...")
        # Fix: previously fell through and exited 0; 130 (128 + SIGINT) is
        # the conventional exit status for Ctrl-C.
        sys.exit(130)
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
from abc import ABC, abstractmethod
2+
from pathlib import Path
3+
from tempfile import gettempdir
4+
from typing import Any
5+
6+
7+
class AbstractCache(ABC):
8+
"""
9+
Abstract base class for caching of LLM interactions.
10+
"""
11+
12+
def __init__(self, path: Path, **kwargs: Any):
13+
self._path = path
14+
self._config = kwargs
15+
16+
@property
17+
def path(self) -> Path:
18+
return self._path
19+
20+
@abstractmethod
21+
def lookup(
22+
self,
23+
messages: list[dict[str, Any]],
24+
25+
) -> str | None:
26+
"""Lookup a cached response for the given messages.
27+
28+
Args:
29+
messages: The list of messages representing the conversation history.
30+
"""
31+
pass
32+
33+
@abstractmethod
34+
def update(
35+
self,
36+
messages: list[dict[str, Any]],
37+
response: str
38+
) -> None:
39+
"""Store a response in the cache for the given messages.
40+
41+
Args:
42+
messages: The list of messages representing the conversation history.
43+
response: The response text to cache.
44+
"""
45+
pass
46+
47+
@abstractmethod
48+
def clear(self) -> None:
49+
"""Clear the entire cache."""
50+
pass
51+
52+
def flush(self) -> None: # noqa: B027
53+
"""
54+
Optional: Persist cache to disk.
55+
Not abstract because not all implementations need it.
56+
"""
57+
pass
58+
59+
60+
class DirectoryCache(AbstractCache):
61+
"""
62+
Cache that stores cached responses in a directory.
63+
If no path is specified, a temporary directory is used.
64+
"""
65+
66+
def __init__(self, path: Path | None = None, **kwargs: Any):
67+
if path is None:
68+
path = Path(gettempdir()) / "c2rust_postprocess"
69+
super().__init__(path, **kwargs)
70+
self._path.mkdir(parents=True, exist_ok=True)
71+
72+
def get_cache_file_name(self, messages: list[dict[str, Any]]) -> Path:
73+
import hashlib
74+
import json
75+
76+
messages_str = json.dumps(messages, sort_keys=True)
77+
hash_digest = hashlib.sha256(messages_str.encode()).hexdigest()
78+
return self._path / f"{hash_digest}.txt"
79+
80+
def lookup(
81+
self,
82+
messages: list[dict[str, Any]],
83+
) -> str | None:
84+
cache_file = self.get_cache_file_name(messages)
85+
86+
if cache_file.exists():
87+
with open(cache_file, encoding='utf-8') as f:
88+
return f.read()
89+
return None
90+
91+
def update(
92+
self,
93+
messages: list[dict[str, Any]],
94+
response: str
95+
) -> None:
96+
cache_file = self.get_cache_file_name(messages)
97+
98+
with open(cache_file, 'w', encoding='utf-8') as f:
99+
f.write(response)
100+
101+
def clear(self) -> None:
102+
self._path.unlink(missing_ok=True)

0 commit comments

Comments
 (0)