
Commit

Merge pull request stanford-oval#102 from zenith110/users/zenith110/serper

Add support for SerperRM.
shaoyijia authored Jul 31, 2024
2 parents 6285589 + e42ba0a commit fe00f38
Showing 8 changed files with 316 additions and 5 deletions.
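At a glance, the new retriever plugs into the STORM pipeline the same way the existing retrievers do. Below is a minimal sketch condensed from the example script added in this commit; the LM configuration is elided, and the topic and output directory are placeholders, not part of the commit itself.

import os

from knowledge_storm import STORMWikiRunnerArguments, STORMWikiRunner, STORMWikiLMConfigs
from knowledge_storm.rm import SerperRM

lm_configs = STORMWikiLMConfigs()  # populate with your LM choices, as in examples/run_storm_wiki_serper.py
engine_args = STORMWikiRunnerArguments(output_dir="./results/serper")

# Serper query parameters; 'q' and 'type' are filled in per query by SerperRM.forward.
rm = SerperRM(
    serper_search_api_key=os.getenv("SERPER_API_KEY"),
    query_params={"autocorrect": True, "num": 10, "page": 1},
)

runner = STORMWikiRunner(engine_args, lm_configs, rm)
runner.run(
    topic="example topic",  # placeholder
    do_research=True,
    do_generate_outline=True,
    do_generate_article=True,
    do_polish_article=True,
)
runner.post_run()
runner.summary()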
2 changes: 1 addition & 1 deletion examples/helper/process_kaggle_arxiv_abstract_dataset.py
@@ -25,4 +25,4 @@
df['description'] = ''

print(f'The downsampled dataset has {len(df)} samples.')
df.to_csv(args.output_path, index=False)
df.to_csv(args.output_path, index=False)
2 changes: 1 addition & 1 deletion examples/run_storm_wiki_claude.py
@@ -114,4 +114,4 @@ def main(args):
parser.add_argument('--remove-duplicate', action='store_true',
help='If True, remove duplicate content from the article.')

main(parser.parse_args())
main(parser.parse_args())
2 changes: 1 addition & 1 deletion examples/run_storm_wiki_gpt.py
@@ -126,4 +126,4 @@ def main(args):
parser.add_argument('--remove-duplicate', action='store_true',
help='If True, remove duplicate content from the article.')

main(parser.parse_args())
main(parser.parse_args())
2 changes: 1 addition & 1 deletion examples/run_storm_wiki_gpt_with_VectorRM.py
@@ -169,4 +169,4 @@ def main(args):
help='Top k collected references for each section title.')
parser.add_argument('--remove-duplicate', action='store_true',
help='If True, remove duplicate content from the article.')
main(parser.parse_args())
main(parser.parse_args())
2 changes: 1 addition & 1 deletion examples/run_storm_wiki_mistral.py
@@ -172,4 +172,4 @@ def main(args):
parser.add_argument('--remove-duplicate', action='store_true',
help='If True, remove duplicate content from the article.')

main(parser.parse_args())
main(parser.parse_args())
175 changes: 175 additions & 0 deletions examples/run_storm_wiki_serper.py
@@ -0,0 +1,175 @@
"""
STORM Wiki pipeline powered by Claude family models and the Serper search engine.
You need to set up the following environment variables to run this script:
- ANTHROPIC_API_KEY: Anthropic API key
- SERPER_API_KEY: Serper.dev API key
Output will be structured as below
args.output_dir/
topic_name/ # topic_name will follow convention of underscore-connected topic name w/o space and slash
conversation_log.json # Log of information-seeking conversation
raw_search_results.json # Raw search results from search engine
direct_gen_outline.txt # Outline directly generated with LLM's parametric knowledge
storm_gen_outline.txt # Outline refined with collected information
url_to_info.json # Sources that are used in the final article
storm_gen_article.txt # Final article generated
storm_gen_article_polished.txt # Polished final article (if args.do_polish_article is True)
"""

import os
from argparse import ArgumentParser

from knowledge_storm import (
STORMWikiRunnerArguments,
STORMWikiRunner,
STORMWikiLMConfigs,
)
from knowledge_storm.lm import ClaudeModel
from knowledge_storm.rm import SerperRM
from knowledge_storm.utils import load_api_key


def main(args):
load_api_key(toml_file_path="secrets.toml")
lm_configs = STORMWikiLMConfigs()
claude_kwargs = {
"api_key": os.getenv("ANTHROPIC_API_KEY"),
"temperature": 1.0,
"top_p": 0.9,
}

# STORM is an LM system, so different components can be powered by different models.
# For a good balance between cost and quality, you can choose a cheaper/faster model for conv_simulator_lm,
# which is used to split queries and synthesize answers in the conversation. We recommend using stronger models
# for outline_gen_lm, which is responsible for organizing the collected information, and article_gen_lm,
# which is responsible for generating sections with citations.
conv_simulator_lm = ClaudeModel(
model="claude-3-haiku-20240307", max_tokens=500, **claude_kwargs
)
question_asker_lm = ClaudeModel(
model="claude-3-sonnet-20240229", max_tokens=500, **claude_kwargs
)
outline_gen_lm = ClaudeModel(
model="claude-3-opus-20240229", max_tokens=400, **claude_kwargs
)
article_gen_lm = ClaudeModel(
model="claude-3-opus-20240229", max_tokens=700, **claude_kwargs
)
article_polish_lm = ClaudeModel(
model="claude-3-opus-20240229", max_tokens=4000, **claude_kwargs
)

lm_configs.set_conv_simulator_lm(conv_simulator_lm)
lm_configs.set_question_asker_lm(question_asker_lm)
lm_configs.set_outline_gen_lm(outline_gen_lm)
lm_configs.set_article_gen_lm(article_gen_lm)
lm_configs.set_article_polish_lm(article_polish_lm)

engine_args = STORMWikiRunnerArguments(
output_dir=args.output_dir,
max_conv_turn=args.max_conv_turn,
max_perspective=args.max_perspective,
search_top_k=args.search_top_k,
max_thread_num=args.max_thread_num,
)
# Documentation for generating the data is available here:
# https://serper.dev/playground
# Note that tbs (the date range) only accepts hardcoded values.
# num is the number of results per page; it is recommended to use increments of 10 (10, 20, etc.).
# page is how many pages will be searched.
# h1 is where the Google search will originate from.
topic = input("topic: ")
data = {"autocorrect": True, "num": 10, "page": 1}
rm = SerperRM(serper_search_api_key=os.getenv("SERPER_API_KEY"), query_params=data)

runner = STORMWikiRunner(engine_args, lm_configs, rm)

runner.run(
topic=topic,
do_research=args.do_research,
do_generate_outline=args.do_generate_outline,
do_generate_article=args.do_generate_article,
do_polish_article=args.do_polish_article,
)
runner.post_run()
runner.summary()


if __name__ == "__main__":
parser = ArgumentParser()
# global arguments
parser.add_argument(
"--output-dir",
type=str,
default="./results/serper",
help="Directory to store the outputs.",
)
parser.add_argument(
"--max-thread-num",
type=int,
default=3,
help="Maximum number of threads to use. The information seeking part and the article generation"
"part can speed up by using multiple threads. Consider reducing it if keep getting "
'"Exceed rate limit" error when calling LM API.',
)
parser.add_argument(
"--retriever",
type=str,
choices=["bing", "you", "serper"],
help="The search engine API to use for retrieving information.",
)
# stage of the pipeline
parser.add_argument(
"--do-research",
action="store_true",
help="If True, simulate conversation to research the topic; otherwise, load the results.",
)
parser.add_argument(
"--do-generate-outline",
action="store_true",
help="If True, generate an outline for the topic; otherwise, load the results.",
)
parser.add_argument(
"--do-generate-article",
action="store_true",
help="If True, generate an article for the topic; otherwise, load the results.",
)
parser.add_argument(
"--do-polish-article",
action="store_true",
help="If True, polish the article by adding a summarization section and (optionally) removing "
"duplicate content.",
)
# hyperparameters for the pre-writing stage
parser.add_argument(
"--max-conv-turn",
type=int,
default=3,
help="Maximum number of questions in conversational question asking.",
)
parser.add_argument(
"--max-perspective",
type=int,
default=3,
help="Maximum number of perspectives to consider in perspective-guided question asking.",
)
parser.add_argument(
"--search-top-k",
type=int,
default=3,
help="Top k search results to consider for each search query.",
)
# hyperparameters for the writing stage
parser.add_argument(
"--retrieve-top-k",
type=int,
default=3,
help="Top k collected references for each section title.",
)
parser.add_argument(
"--remove-duplicate",
action="store_true",
help="If True, remove duplicate content from the article.",
)

main(parser.parse_args())
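The query-parameter comment in the script above notes that tbs accepts fixed date-range values and that num controls the number of results per page. A hedged sketch of a stricter configuration follows, assuming the qdr values listed in the SerperRM docstring further below; the specific values here are illustrative and not part of this commit.

import os

from knowledge_storm.rm import SerperRM

# Illustrative query_params: restrict results to the past week (tbs='qdr:w'),
# fetch 20 results per page, and search only the first page.
data = {
    "autocorrect": True,
    "num": 20,       # results per page, in increments of 10
    "page": 1,       # number of pages to search
    "tbs": "qdr:w",  # date range: past week (see the SerperRM docstring)
}
rm = SerperRM(serper_search_api_key=os.getenv("SERPER_API_KEY"), query_params=data)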
135 changes: 135 additions & 0 deletions knowledge_storm/rm.py
@@ -404,3 +404,138 @@ def forward(self, query_or_queries: Union[str, List[str]], exclude_urls: List[str]):
})

return collected_results

class SerperRM(dspy.Retrieve):
"""Retrieve information from custom queries using Serper.dev."""

def __init__(self, serper_search_api_key=None, query_params=None):
"""Args:
serper_search_api_key (str): API key for Serper; can be obtained by creating an account on https://serper.dev/
query_params (dict or list of dict): a dictionary, or a list of dictionaries with a maximum size of 100, of parameters used for the query.
Commonly used fields are as follows (see more information at https://serper.dev/playground):
q (str): query that will be used with Google search
type (str): type of Google search to run. Types include search, images, video, maps, places, etc.
gl (str): country that the search will focus on
location (str): country where the search will originate from. All locations can be found here: https://api.serper.dev/locations.
autocorrect (bool): enable autocorrect on the queries; misspelled queries will be corrected while searching.
results (int): max number of results per page.
page (int): max number of pages per call.
tbs (str): date/time range, automatically set to any time by default.
qdr:h str: Date time range for the past hour.
qdr:d str: Date time range for the past 24 hours.
qdr:w str: Date time range for past week.
qdr:m str: Date time range for past month.
qdr:y str: Date time range for past year.
"""
super().__init__()
self.usage = 0
self.query_params = query_params
self.serper_search_api_key = serper_search_api_key
if not self.serper_search_api_key and not os.environ.get('SERPER_API_KEY'):
raise RuntimeError(
'You must supply a serper_search_api_key param or set environment variable SERPER_API_KEY'
)

elif self.serper_search_api_key:
self.serper_search_api_key = serper_search_api_key

else:
self.serper_search_api_key = os.environ['SERPER_API_KEY']

self.base_url = 'https://google.serper.dev'

def serper_runner(self, query_params):
self.search_url = f'{self.base_url}/search'

headers = {
'X-API-KEY': self.serper_search_api_key,
'Content-Type': 'application/json',
}

response = requests.request(
'POST', self.search_url, headers=headers, json=query_params
)

if not response.ok:
raise RuntimeError(
f'An error occurred while running the search process.\nThe request failed with status code {response.status_code}: {response.reason}'
)

return response.json()

def get_usage_and_reset(self):
usage = self.usage
self.usage = 0
return {'SerperRM': usage}

def forward(self, query_or_queries: Union[str, List[str]], exclude_urls: List[str]):
"""
Calls the API and searches for the query passed in.
Args:
query_or_queries (Union[str, List[str]]): The query or queries to search for.
exclude_urls (List[str]): Dummy parameter to match the interface. Does not have any effect.
Returns:
a list of dictionaries; each dictionary has the keys 'description', 'snippets' (a list of strings), 'title', and 'url'
"""
queries = (
[query_or_queries]
if isinstance(query_or_queries, str)
else query_or_queries
)

self.usage += len(queries)
self.results = []
collected_results = []
for query in queries:
if query == 'Queries:':
continue
query_params = self.query_params

# All available parameters can be found in the playground: https://serper.dev/playground
# Set the value of 'q' to the query currently being processed.
query_params['q'] = query

# Set the type to 'search'; other Google-provided types include images, videos, places, maps, etc.
query_params['type'] = 'search'

self.result = self.serper_runner(query_params)
self.results.append(self.result)

# List of dictionaries that will be used by STORM to create the JSON outputs.
collected_results = []

for result in self.results:
try:
# Each organic result is a dictionary containing the snippet, title, and URL of a document that will be used.
organic_results = result.get('organic')

knowledge_graph = result.get('knowledgeGraph')
for organic in organic_results:
snippets = []
snippets.append(organic.get('snippet'))
if knowledge_graph is not None:
collected_results.append(
{
'snippets': snippets,
'title': organic.get('title'),
'url': organic.get('link'),
'description': knowledge_graph.get('description'),
}
)
else:
# Common for knowledge graph to be None, set description to empty string
collected_results.append(
{
'snippets': snippets,
'title': organic.get('title'),
'url': organic.get('link'),
'description': '',
}
)
except:
continue

return collected_results
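As documented in forward's docstring, SerperRM returns a list of dictionaries with the keys 'description', 'snippets', 'title', and 'url'. A minimal usage sketch, assuming SERPER_API_KEY is set and using a placeholder query (exclude_urls is accepted only for interface compatibility and is ignored):

import os

from knowledge_storm.rm import SerperRM

rm = SerperRM(
    serper_search_api_key=os.getenv("SERPER_API_KEY"),
    query_params={"autocorrect": True, "num": 10, "page": 1},
)
results = rm.forward("large language models", exclude_urls=[])
for r in results:
    # Each entry mirrors the structure STORM expects from a retrieval module.
    print(r["title"], r["url"])
    print(r["description"])
    print(r["snippets"])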
1 change: 1 addition & 0 deletions requirements.txt
@@ -7,3 +7,4 @@ trafilatura
langchain-huggingface
qdrant-client
langchain-qdrant
numpy==1.26.4
