Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New RM Serper #102

Merged
merged 9 commits into from
Jul 31, 2024
2 changes: 1 addition & 1 deletion examples/helper/process_kaggle_arxiv_abstract_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,4 @@
df['description'] = ''

print(f'The downsampled dataset has {len(df)} samples.')
df.to_csv(args.output_path, index=False)
df.to_csv(args.output_path, index=False)
zenith110 marked this conversation as resolved.
Show resolved Hide resolved
2 changes: 1 addition & 1 deletion examples/run_storm_wiki_claude.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,4 +114,4 @@ def main(args):
parser.add_argument('--remove-duplicate', action='store_true',
help='If True, remove duplicate content from the article.')

main(parser.parse_args())
main(parser.parse_args())
zenith110 marked this conversation as resolved.
Show resolved Hide resolved
2 changes: 1 addition & 1 deletion examples/run_storm_wiki_gpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,4 +126,4 @@ def main(args):
parser.add_argument('--remove-duplicate', action='store_true',
help='If True, remove duplicate content from the article.')

main(parser.parse_args())
main(parser.parse_args())
2 changes: 1 addition & 1 deletion examples/run_storm_wiki_gpt_with_VectorRM.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,4 +169,4 @@ def main(args):
help='Top k collected references for each section title.')
parser.add_argument('--remove-duplicate', action='store_true',
help='If True, remove duplicate content from the article.')
main(parser.parse_args())
main(parser.parse_args())
zenith110 marked this conversation as resolved.
Show resolved Hide resolved
2 changes: 1 addition & 1 deletion examples/run_storm_wiki_mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,4 +172,4 @@ def main(args):
parser.add_argument('--remove-duplicate', action='store_true',
help='If True, remove duplicate content from the article.')

main(parser.parse_args())
main(parser.parse_args())
zenith110 marked this conversation as resolved.
Show resolved Hide resolved
175 changes: 175 additions & 0 deletions examples/run_storm_wiki_serper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
"""
STORM Wiki pipeline powered by Claude family models and serper search engine.
You need to set up the following environment variables to run this script:
- ANTHROPIC_API_KEY: Anthropic API key
- SERPER_API_KEY: Serper.dev api key

Output will be structured as below
args.output_dir/
topic_name/ # topic_name will follow convention of underscore-connected topic name w/o space and slash
conversation_log.json # Log of information-seeking conversation
raw_search_results.json # Raw search results from search engine
direct_gen_outline.txt # Outline directly generated with LLM's parametric knowledge
storm_gen_outline.txt # Outline refined with collected information
url_to_info.json # Sources that are used in the final article
storm_gen_article.txt # Final article generated
storm_gen_article_polished.txt # Polished final article (if args.do_polish_article is True)
"""

import os
from argparse import ArgumentParser

from knowledge_storm import (
STORMWikiRunnerArguments,
STORMWikiRunner,
STORMWikiLMConfigs,
)
from knowledge_storm.lm import ClaudeModel
from knowledge_storm.rm import SerperRM
from knowledge_storm.utils import load_api_key


def main(args):
load_api_key(toml_file_path="secrets.toml")
lm_configs = STORMWikiLMConfigs()
claude_kwargs = {
"api_key": os.getenv("ANTHROPIC_API_KEY"),
"temperature": 1.0,
"top_p": 0.9,
}

# STORM is a LM system so different components can be powered by different models.
# For a good balance between cost and quality, you can choose a cheaper/faster model for conv_simulator_lm
# which is used to split queries, synthesize answers in the conversation. We recommend using stronger models
# for outline_gen_lm which is responsible for organizing the collected information, and article_gen_lm
# which is responsible for generating sections with citations.
conv_simulator_lm = ClaudeModel(
model="claude-3-haiku-20240307", max_tokens=500, **claude_kwargs
)
question_asker_lm = ClaudeModel(
model="claude-3-sonnet-20240229", max_tokens=500, **claude_kwargs
)
outline_gen_lm = ClaudeModel(
model="claude-3-opus-20240229", max_tokens=400, **claude_kwargs
)
article_gen_lm = ClaudeModel(
model="claude-3-opus-20240229", max_tokens=700, **claude_kwargs
)
article_polish_lm = ClaudeModel(
model="claude-3-opus-20240229", max_tokens=4000, **claude_kwargs
)

lm_configs.set_conv_simulator_lm(conv_simulator_lm)
lm_configs.set_question_asker_lm(question_asker_lm)
lm_configs.set_outline_gen_lm(outline_gen_lm)
lm_configs.set_article_gen_lm(article_gen_lm)
lm_configs.set_article_polish_lm(article_polish_lm)

engine_args = STORMWikiRunnerArguments(
output_dir=args.output_dir,
max_conv_turn=args.max_conv_turn,
max_perspective=args.max_perspective,
search_top_k=args.search_top_k,
max_thread_num=args.max_thread_num,
)
# Documentation to generate the data is available here:
# https://serper.dev/playground
# Important to note that tbs(date range is hardcoded values).
# num is results per pages and is recommended to use in increments of 10(10, 20, etc).
# page is how many pages will be searched.
# h1 is where the google search will orginate from.
topic = input("topic: ")
data = {"autocorrect": True, "num": 10, "page": 1}
rm = SerperRM(serper_search_api_key=os.getenv("SERPER_API_KEY"), query_params=data)

runner = STORMWikiRunner(engine_args, lm_configs, rm)

runner.run(
topic=topic,
do_research=args.do_research,
do_generate_outline=args.do_generate_outline,
do_generate_article=args.do_generate_article,
do_polish_article=args.do_polish_article,
)
runner.post_run()
runner.summary()


if __name__ == "__main__":
parser = ArgumentParser()
# global arguments
parser.add_argument(
"--output-dir",
type=str,
default="./results/serper",
help="Directory to store the outputs.",
)
parser.add_argument(
"--max-thread-num",
type=int,
default=3,
help="Maximum number of threads to use. The information seeking part and the article generation"
"part can speed up by using multiple threads. Consider reducing it if keep getting "
'"Exceed rate limit" error when calling LM API.',
)
parser.add_argument(
"--retriever",
type=str,
choices=["bing", "you", "serper"],
help="The search engine API to use for retrieving information.",
)
# stage of the pipeline
parser.add_argument(
"--do-research",
action="store_true",
help="If True, simulate conversation to research the topic; otherwise, load the results.",
)
parser.add_argument(
"--do-generate-outline",
action="store_true",
help="If True, generate an outline for the topic; otherwise, load the results.",
)
parser.add_argument(
"--do-generate-article",
action="store_true",
help="If True, generate an article for the topic; otherwise, load the results.",
)
parser.add_argument(
"--do-polish-article",
action="store_true",
help="If True, polish the article by adding a summarization section and (optionally) removing "
"duplicate content.",
)
# hyperparameters for the pre-writing stage
parser.add_argument(
"--max-conv-turn",
type=int,
default=3,
help="Maximum number of questions in conversational question asking.",
)
parser.add_argument(
"--max-perspective",
type=int,
default=3,
help="Maximum number of perspectives to consider in perspective-guided question asking.",
)
parser.add_argument(
"--search-top-k",
type=int,
default=3,
help="Top k search results to consider for each search query.",
)
# hyperparameters for the writing stage
parser.add_argument(
"--retrieve-top-k",
type=int,
default=3,
help="Top k collected references for each section title.",
)
parser.add_argument(
"--remove-duplicate",
action="store_true",
help="If True, remove duplicate content from the article.",
)

main(parser.parse_args())
135 changes: 135 additions & 0 deletions knowledge_storm/rm.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,3 +404,138 @@ def forward(self, query_or_queries: Union[str, List[str]], exclude_urls: List[st
})

return collected_results

class SerperRM(dspy.Retrieve):
"""Retrieve information from custom queries using Serper.dev."""

def __init__(self, serper_search_api_key=None, query_params=None):
"""Args:
serper_search_api_key str: API key to run serper, can be found by creating an account on https://serper.dev/
query_params (dict or list of dict): parameters in dictionary or list of dictionaries that has a max size of 100 that will be used to query.
Commonly used fields are as follows (see more information in https://serper.dev/playground):
q str: query that will be used with google search
type str: type that will be used for browsing google. Types are search, images, video, maps, places, etc.
gl str: Country that will be focused on for the search
location str: Country where the search will originate from. All locates can be found here: https://api.serper.dev/locations.
autocorrect bool: Enable autocorrect on the queries while searching, if query is misspelled, will be updated.
results int: Max number of results per page.
page int: Max number of pages per call.
tbs str: date time range, automatically set to any time by default.
qdr:h str: Date time range for the past hour.
qdr:d str: Date time range for the past 24 hours.
qdr:w str: Date time range for past week.
qdr:m str: Date time range for past month.
qdr:y str: Date time range for past year.
"""
super().__init__()
self.usage = 0
self.query_params = query_params
self.serper_search_api_key = serper_search_api_key
if not self.serper_search_api_key and not os.environ.get('SERPER_API_KEY'):
raise RuntimeError(
'You must supply a serper_search_api_key param or set environment variable SERPER_API_KEY'
)

elif self.serper_search_api_key:
self.serper_search_api_key = serper_search_api_key

else:
self.serper_search_api_key = os.environ['SERPER_API_KEY']

self.base_url = 'https://google.serper.dev'

def serper_runner(self, query_params):
self.search_url = f'{self.base_url}/search'

headers = {
'X-API-KEY': self.serper_search_api_key,
'Content-Type': 'application/json',
}

response = requests.request(
'POST', self.search_url, headers=headers, json=query_params
)

if response == None:
raise RuntimeError(
f'Error had occured while running the search process.\n Error is {response.reason}, had failed with status code {response.status_code}'
)

return response.json()

def get_usage_and_reset(self):
usage = self.usage
self.usage = 0
return {'SerperRM': usage}

def forward(self, query_or_queries: Union[str, List[str]], exclude_urls: List[str]):
"""
Calls the API and searches for the query passed in.


Args:
query_or_queries (Union[str, List[str]]): The query or queries to search for.
exclude_urls (List[str]): Dummy parameter to match the interface. Does not have any effect.

Returns:
a list of dictionaries, each dictionary has keys of 'description', 'snippets' (list of strings), 'title', 'url'
"""
queries = (
[query_or_queries]
if isinstance(query_or_queries, str)
else query_or_queries
)

self.usage += len(queries)
self.results = []
collected_results = []
for query in queries:
if query == 'Queries:':
continue
query_params = self.query_params
zenith110 marked this conversation as resolved.
Show resolved Hide resolved

# All available parameters can be found in the playground: https://serper.dev/playground
# Sets the json value for query to be the query that is being parsed.
query_params['q'] = query

# Sets the type to be search, can be images, video, places, maps etc that Google provides.
query_params['type'] = 'search'

self.result = self.serper_runner(query_params)
self.results.append(self.result)

# Array of dictionaries that will be used by Storm to create the jsons
collected_results = []

for result in self.results:
try:
# An array of dictionaries that contains the snippets, title of the document and url that will be used.
organic_results = result.get('organic')

knowledge_graph = result.get('knowledgeGraph')
for organic in organic_results:
snippets = []
snippets.append(organic.get('snippet'))
if knowledge_graph != None:
collected_results.append(
{
'snippets': snippets,
'title': organic.get('title'),
'url': organic.get('link'),
'description': knowledge_graph.get('description'),
}
)
else:
# Common for knowledge graph to be None, set description to empty string
collected_results.append(
{
'snippets': snippets,
'title': result.get('title'),
'url': result.get('link'),
'description': '',
}
)
except:
continue

return collected_results
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@ trafilatura
langchain-huggingface
qdrant-client
langchain-qdrant
numpy==1.26.4