diff --git a/examples/helper/process_kaggle_arxiv_abstract_dataset.py b/examples/helper/process_kaggle_arxiv_abstract_dataset.py
index 4cb885c1..1a07d062 100644
--- a/examples/helper/process_kaggle_arxiv_abstract_dataset.py
+++ b/examples/helper/process_kaggle_arxiv_abstract_dataset.py
@@ -25,4 +25,4 @@
 df['description'] = ''
 
 print(f'The downsampled dataset has {len(df)} samples.')
-df.to_csv(args.output_path, index=False)
+df.to_csv(args.output_path, index=False)
\ No newline at end of file
diff --git a/examples/run_storm_wiki_claude.py b/examples/run_storm_wiki_claude.py
index 31fef1e1..64031f6a 100644
--- a/examples/run_storm_wiki_claude.py
+++ b/examples/run_storm_wiki_claude.py
@@ -114,4 +114,4 @@ def main(args):
     parser.add_argument('--remove-duplicate', action='store_true',
                         help='If True, remove duplicate content from the article.')
 
-    main(parser.parse_args())
+    main(parser.parse_args())
\ No newline at end of file
diff --git a/examples/run_storm_wiki_gpt.py b/examples/run_storm_wiki_gpt.py
index b7968152..5e4dda3a 100644
--- a/examples/run_storm_wiki_gpt.py
+++ b/examples/run_storm_wiki_gpt.py
@@ -126,4 +126,4 @@ def main(args):
     parser.add_argument('--remove-duplicate', action='store_true',
                         help='If True, remove duplicate content from the article.')
 
-    main(parser.parse_args())
+    main(parser.parse_args())
\ No newline at end of file
diff --git a/examples/run_storm_wiki_gpt_with_VectorRM.py b/examples/run_storm_wiki_gpt_with_VectorRM.py
index 2c07ffc2..d6662d3c 100644
--- a/examples/run_storm_wiki_gpt_with_VectorRM.py
+++ b/examples/run_storm_wiki_gpt_with_VectorRM.py
@@ -169,4 +169,4 @@ def main(args):
                         help='Top k collected references for each section title.')
     parser.add_argument('--remove-duplicate', action='store_true',
                         help='If True, remove duplicate content from the article.')
-    main(parser.parse_args())
+    main(parser.parse_args())
\ No newline at end of file
diff --git a/examples/run_storm_wiki_mistral.py b/examples/run_storm_wiki_mistral.py
index eb6a4ff6..76f85dc3 100644
--- a/examples/run_storm_wiki_mistral.py
+++ b/examples/run_storm_wiki_mistral.py
@@ -172,4 +172,4 @@ def main(args):
     parser.add_argument('--remove-duplicate', action='store_true',
                         help='If True, remove duplicate content from the article.')
 
-    main(parser.parse_args())
+    main(parser.parse_args())
\ No newline at end of file
diff --git a/examples/run_storm_wiki_serper.py b/examples/run_storm_wiki_serper.py
new file mode 100644
index 00000000..df27b98b
--- /dev/null
+++ b/examples/run_storm_wiki_serper.py
@@ -0,0 +1,175 @@
+"""
+STORM Wiki pipeline powered by Claude family models and serper search engine.
+
+You need to set up the following environment variables to run this script:
+    - ANTHROPIC_API_KEY: Anthropic API key
+    - SERPER_API_KEY: Serper.dev API key
+
+Output will be structured as below
+args.output_dir/
+    topic_name/  # topic_name will follow convention of underscore-connected topic name w/o space and slash
+        conversation_log.json           # Log of information-seeking conversation
+        raw_search_results.json         # Raw search results from search engine
+        direct_gen_outline.txt          # Outline directly generated with LLM's parametric knowledge
+        storm_gen_outline.txt           # Outline refined with collected information
+        url_to_info.json                # Sources that are used in the final article
+        storm_gen_article.txt           # Final article generated
+        storm_gen_article_polished.txt  # Polished final article (if args.do_polish_article is True)
+"""
+
+import os
+from argparse import ArgumentParser
+
+from knowledge_storm import (
+    STORMWikiRunnerArguments,
+    STORMWikiRunner,
+    STORMWikiLMConfigs,
+)
+from knowledge_storm.lm import ClaudeModel
+from knowledge_storm.rm import SerperRM
+from knowledge_storm.utils import load_api_key
+
+
+def main(args):
+    load_api_key(toml_file_path="secrets.toml")
+    lm_configs = STORMWikiLMConfigs()
+    claude_kwargs = {
+        "api_key": os.getenv("ANTHROPIC_API_KEY"),
+        "temperature": 1.0,
+        "top_p": 0.9,
+    }
+
+    # STORM is an LM system, so different components can be powered by different models.
+    # For a good balance between cost and quality, you can choose a cheaper/faster model for conv_simulator_lm,
+    # which is used to split queries and synthesize answers in the conversation. We recommend using stronger models
+    # for outline_gen_lm, which is responsible for organizing the collected information, and article_gen_lm,
+    # which is responsible for generating sections with citations.
+    conv_simulator_lm = ClaudeModel(
+        model="claude-3-haiku-20240307", max_tokens=500, **claude_kwargs
+    )
+    question_asker_lm = ClaudeModel(
+        model="claude-3-sonnet-20240229", max_tokens=500, **claude_kwargs
+    )
+    outline_gen_lm = ClaudeModel(
+        model="claude-3-opus-20240229", max_tokens=400, **claude_kwargs
+    )
+    article_gen_lm = ClaudeModel(
+        model="claude-3-opus-20240229", max_tokens=700, **claude_kwargs
+    )
+    article_polish_lm = ClaudeModel(
+        model="claude-3-opus-20240229", max_tokens=4000, **claude_kwargs
+    )
+
+    lm_configs.set_conv_simulator_lm(conv_simulator_lm)
+    lm_configs.set_question_asker_lm(question_asker_lm)
+    lm_configs.set_outline_gen_lm(outline_gen_lm)
+    lm_configs.set_article_gen_lm(article_gen_lm)
+    lm_configs.set_article_polish_lm(article_polish_lm)
+
+    engine_args = STORMWikiRunnerArguments(
+        output_dir=args.output_dir,
+        max_conv_turn=args.max_conv_turn,
+        max_perspective=args.max_perspective,
+        search_top_k=args.search_top_k,
+        max_thread_num=args.max_thread_num,
+    )
+    # Documentation for the Serper query parameters is available here:
+    # https://serper.dev/playground
+    # Note that tbs (the date range) only accepts a fixed set of values.
+    # num is the number of results per page; it is recommended to use increments of 10 (10, 20, etc.).
+    # page is the number of pages that will be searched.
+    # hl sets the language in which the Google search results are returned.
+    topic = input("topic: ")
+    data = {"autocorrect": True, "num": 10, "page": 1}
+    rm = SerperRM(serper_search_api_key=os.getenv("SERPER_API_KEY"), query_params=data)
+
+    runner = STORMWikiRunner(engine_args, lm_configs, rm)
+
+    runner.run(
+        topic=topic,
+        do_research=args.do_research,
+        do_generate_outline=args.do_generate_outline,
+        do_generate_article=args.do_generate_article,
+        do_polish_article=args.do_polish_article,
+    )
+    runner.post_run()
+    runner.summary()
+
+
+if __name__ == "__main__":
+    parser = ArgumentParser()
+    # global arguments
+    parser.add_argument(
+        "--output-dir",
+        type=str,
+        default="./results/serper",
+        help="Directory to store the outputs.",
+    )
+    parser.add_argument(
+        "--max-thread-num",
+        type=int,
+        default=3,
+        help="Maximum number of threads to use. The information seeking part and the article generation "
+        "part can be sped up by using multiple threads. Consider reducing it if you keep getting "
+        '"Exceed rate limit" errors when calling the LM API.',
+    )
+    parser.add_argument(
+        "--retriever",
+        type=str,
+        choices=["bing", "you", "serper"],
+        help="The search engine API to use for retrieving information.",
+    )
+    # stage of the pipeline
+    parser.add_argument(
+        "--do-research",
+        action="store_true",
+        help="If True, simulate conversation to research the topic; otherwise, load the results.",
+    )
+    parser.add_argument(
+        "--do-generate-outline",
+        action="store_true",
+        help="If True, generate an outline for the topic; otherwise, load the results.",
+    )
+    parser.add_argument(
+        "--do-generate-article",
+        action="store_true",
+        help="If True, generate an article for the topic; otherwise, load the results.",
+    )
+    parser.add_argument(
+        "--do-polish-article",
+        action="store_true",
+        help="If True, polish the article by adding a summarization section and (optionally) removing "
+        "duplicate content.",
+    )
+    # hyperparameters for the pre-writing stage
+    parser.add_argument(
+        "--max-conv-turn",
+        type=int,
+        default=3,
+        help="Maximum number of questions in conversational question asking.",
+    )
+    parser.add_argument(
+        "--max-perspective",
+        type=int,
+        default=3,
+        help="Maximum number of perspectives to consider in perspective-guided question asking.",
+    )
+    parser.add_argument(
+        "--search-top-k",
+        type=int,
+        default=3,
+        help="Top k search results to consider for each search query.",
+    )
+    # hyperparameters for the writing stage
+    parser.add_argument(
+        "--retrieve-top-k",
+        type=int,
+        default=3,
+        help="Top k collected references for each section title.",
+    )
+    parser.add_argument(
+        "--remove-duplicate",
+        action="store_true",
+        help="If True, remove duplicate content from the article.",
+    )
+
+    main(parser.parse_args())
diff --git a/knowledge_storm/rm.py b/knowledge_storm/rm.py
index 86f59703..0951ef33 100644
--- a/knowledge_storm/rm.py
+++ b/knowledge_storm/rm.py
@@ -404,3 +404,138 @@ def forward(self, query_or_queries: Union[str, List[str]], exclude_urls: List[st
                 })
 
         return collected_results
+
+class SerperRM(dspy.Retrieve):
+    """Retrieve information from custom queries using Serper.dev."""
+
+    def __init__(self, serper_search_api_key=None, query_params=None):
+        """Args:
+        serper_search_api_key str: API key for Serper; can be obtained by creating an account on https://serper.dev/
+        query_params (dict or list of dict): query parameters as a dictionary, or a list of at most 100 dictionaries, that will be used to query.
+        Commonly used fields are as follows (see more information in https://serper.dev/playground):
+            q str: Query that will be used with Google search.
+            type str: Type that will be used for browsing Google. Types are search, images, video, maps, places, etc.
+            gl str: Country that the search will focus on.
+            location str: Country where the search will originate from. All locations can be found here: https://api.serper.dev/locations.
+            autocorrect bool: Enable autocorrect on the queries while searching; if a query is misspelled, it will be corrected.
+            results int: Max number of results per page.
+            page int: Max number of pages per call.
+            tbs str: Date/time range, automatically set to any time by default.
+            qdr:h str: Date/time range for the past hour.
+            qdr:d str: Date/time range for the past 24 hours.
+            qdr:w str: Date/time range for the past week.
+            qdr:m str: Date/time range for the past month.
+            qdr:y str: Date/time range for the past year.
+        """
+        super().__init__()
+        self.usage = 0
+        self.query_params = query_params
+        self.serper_search_api_key = serper_search_api_key
+        if not self.serper_search_api_key and not os.environ.get('SERPER_API_KEY'):
+            raise RuntimeError(
+                'You must supply a serper_search_api_key param or set the environment variable SERPER_API_KEY'
+            )
+
+        elif self.serper_search_api_key:
+            self.serper_search_api_key = serper_search_api_key
+
+        else:
+            self.serper_search_api_key = os.environ['SERPER_API_KEY']
+
+        self.base_url = 'https://google.serper.dev'
+
+    def serper_runner(self, query_params):
+        self.search_url = f'{self.base_url}/search'
+
+        headers = {
+            'X-API-KEY': self.serper_search_api_key,
+            'Content-Type': 'application/json',
+        }
+
+        response = requests.request(
+            'POST', self.search_url, headers=headers, json=query_params
+        )
+
+        if not response.ok:
+            raise RuntimeError(
+                f'An error occurred while running the search process.\nError: {response.reason}, failed with status code {response.status_code}'
+            )
+
+        return response.json()
+
+    def get_usage_and_reset(self):
+        usage = self.usage
+        self.usage = 0
+        return {'SerperRM': usage}
+
+    def forward(self, query_or_queries: Union[str, List[str]], exclude_urls: List[str]):
+        """
+        Calls the API and searches for the query passed in.
+
+
+        Args:
+            query_or_queries (Union[str, List[str]]): The query or queries to search for.
+            exclude_urls (List[str]): Dummy parameter to match the interface. Does not have any effect.
+
+        Returns:
+            A list of dictionaries, where each dictionary has the keys 'description', 'snippets' (list of strings), 'title', and 'url'.
+        """
+        queries = (
+            [query_or_queries]
+            if isinstance(query_or_queries, str)
+            else query_or_queries
+        )
+
+        self.usage += len(queries)
+        self.results = []
+        collected_results = []
+        for query in queries:
+            if query == 'Queries:':
+                continue
+            query_params = self.query_params
+
+            # All available parameters can be found in the playground: https://serper.dev/playground
+            # Set the query field to the query currently being processed.
+            query_params['q'] = query
+
+            # Set the type to 'search'; Google also provides images, video, places, maps, etc.
+            query_params['type'] = 'search'
+
+            self.result = self.serper_runner(query_params)
+            self.results.append(self.result)
+
+        # Array of dictionaries that will be used by STORM to create the JSON outputs.
+        collected_results = []
+
+        for result in self.results:
+            try:
+                # An array of dictionaries containing the snippets, the title of the document, and the url that will be used.
+                organic_results = result.get('organic')
+
+                knowledge_graph = result.get('knowledgeGraph')
+                for organic in organic_results:
+                    snippets = []
+                    snippets.append(organic.get('snippet'))
+                    if knowledge_graph is not None:
+                        collected_results.append(
+                            {
+                                'snippets': snippets,
+                                'title': organic.get('title'),
+                                'url': organic.get('link'),
+                                'description': knowledge_graph.get('description'),
+                            }
+                        )
+                    else:
+                        # It is common for the knowledge graph to be None; set the description to an empty string.
+                        collected_results.append(
+                            {
+                                'snippets': snippets,
+                                'title': organic.get('title'),
+                                'url': organic.get('link'),
+                                'description': '',
+                            }
+                        )
+            except Exception:
+                continue
+
+        return collected_results
diff --git a/requirements.txt b/requirements.txt
index 8ac1b95b..79ee9d68 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -7,3 +7,4 @@ trafilatura
 langchain-huggingface
 qdrant-client
 langchain-qdrant
+numpy==1.26.4
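A minimal sketch of how the SerperRM retriever added in knowledge_storm/rm.py above could be exercised on its own, outside the full STORM pipeline. The class, its forward() and get_usage_and_reset() methods, and the query_params dictionary are taken from this patch; the sketch assumes SERPER_API_KEY is set in the environment (or loaded via secrets.toml as in the example script), that this branch is installed, and the query string is purely illustrative.

import os

from knowledge_storm.rm import SerperRM

# Same query parameters as used in examples/run_storm_wiki_serper.py.
query_params = {"autocorrect": True, "num": 10, "page": 1}
rm = SerperRM(
    serper_search_api_key=os.getenv("SERPER_API_KEY"), query_params=query_params
)

# forward() accepts a single query or a list of queries; exclude_urls is a
# dummy parameter kept only for interface compatibility.
results = rm.forward("large language model agents", exclude_urls=[])
for r in results:
    print(r["title"], r["url"])

# forward() counts the queries it issued; get_usage_and_reset() reports and
# clears that counter, e.g. {'SerperRM': 1} after the single query above.
print(rm.get_usage_and_reset())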