|
1 | 1 | import argparse
|
| 2 | +import csv |
2 | 3 | import os
|
3 |
| -from prompttrail.agent.templates import LinearTemplate |
| 4 | +from io import StringIO |
| 5 | + |
4 | 6 | from prompttrail.agent.runners import CommandLineRunner
|
5 |
| -from prompttrail.models.openai import OpenAIChatCompletionModel, OpenAIModelConfiguration, OpenAIModelParameters |
| 7 | +from prompttrail.agent.templates import LinearTemplate |
| 8 | +from prompttrail.agent.templates.openai import ( |
| 9 | + OpenAIGenerateTemplate, |
| 10 | + OpenAIMessageTemplate, |
| 11 | +) |
6 | 12 | from prompttrail.agent.user_interaction import UserInteractionTextCLIProvider
|
| 13 | +from prompttrail.models.openai import ( |
| 14 | + OpenAIChatCompletionModel, |
| 15 | + OpenAIModelConfiguration, |
| 16 | + OpenAIModelParameters, |
| 17 | +) |
7 | 18 |
|
8 | 19 |
|
9 | 20 | def extract(error_message: str):
|
10 |
| - return f"""# Error Summary |
| 21 | + template = LinearTemplate( |
| 22 | + [ |
| 23 | + OpenAIMessageTemplate( |
| 24 | + role="system", |
| 25 | + content=""" |
| 26 | +You're an AI assistant that helps software engineer to understand test error message. |
| 27 | +Your input is error message that emitted by test codes. |
| 28 | +Your output is the summary csv of the error message. |
| 29 | +Each line in the summary csv descributes each error in the error message. |
| 30 | +the summary csv first column indicates where the error occurred, second column is summary of error. |
| 31 | +You emit ONLY the summary csv. No explanation is needed. |
| 32 | +""", |
| 33 | + ), |
| 34 | + OpenAIMessageTemplate( |
| 35 | + role="user", |
| 36 | + content=error_message, |
| 37 | + ), |
| 38 | + OpenAIGenerateTemplate(role="assistant"), |
| 39 | + ] |
| 40 | + ) |
| 41 | + runner = CommandLineRunner( |
| 42 | + model=OpenAIChatCompletionModel( |
| 43 | + configuration=OpenAIModelConfiguration( |
| 44 | + api_key=os.environ.get("OPENAI_API_KEY", "") |
| 45 | + ) |
| 46 | + ), |
| 47 | + parameters=OpenAIModelParameters(model_name="gpt-3.5-turbo-0301"), |
| 48 | + template=template, |
| 49 | + user_interaction_provider=UserInteractionTextCLIProvider(), |
| 50 | + ) |
| 51 | + state = runner.run() |
| 52 | + last_message_content = state.get_last_message().content |
| 53 | + # remove response message header |
| 54 | + csv_str = "\n".join(last_message_content.splitlines()[2:]) |
| 55 | + with StringIO() as f: |
| 56 | + f.write(csv_str) |
| 57 | + f.seek(0) |
| 58 | + csv_content = list(csv.reader(f)) |
| 59 | + content = "\n".join([f"| {row[0]} | {row[1]} |" for row in csv_content]) |
| 60 | + return f"""# Error Summaries |
11 | 61 |
|
12 |
| -TODO: impl summarization of error_message by PromptTail. |
13 |
| -error_mesasge len: {len(error_message)} |
| 62 | +| where | summary | |
| 63 | +| ----- | ------- | |
| 64 | +{content} |
14 | 65 | """
|
15 | 66 |
|
16 | 67 |
|
17 | 68 | def main() -> None:
|
18 | 69 | parser = argparse.ArgumentParser()
|
19 |
| - parser.add_argument('error_filepath') |
20 |
| - parser.add_argument('output_filepath') |
| 70 | + parser.add_argument("error_filepath") |
| 71 | + parser.add_argument("output_filepath") |
21 | 72 | args = parser.parse_args()
|
22 | 73 | with open(args.error_filepath) as f:
|
23 | 74 | error_message = f.read()
|
24 |
| - with open(args.output_filepath, 'w') as f: |
| 75 | + with open(args.output_filepath, "w") as f: |
25 | 76 | f.write(extract(error_message))
|
26 | 77 |
|
27 | 78 |
|
|
0 commit comments