From 1c6dba5d2e75a90262198e649217d86151c43f6f Mon Sep 17 00:00:00 2001
From: zenith110
Date: Sun, 21 Jul 2024 21:17:35 -0400
Subject: [PATCH 1/7] Got the basic framework set up for serper
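
A minimal usage sketch of the new retriever (assuming SERPER_API_KEY is set;
the "q" field and the "tbs" value below are illustrative guesses at Serper's
payload rather than anything this patch validates):

    from knowledge_storm.rm import SerperRM

    # The keys below mirror what validate_input() checks; "q" (the query
    # string) is assumed here and is not validated by this patch.
    query_params = {
        "type": "search",        # selects the endpoint dispatched by runner()
        "q": "knowledge curation",
        "gl": "us",              # two-letter country code
        "hl": "en",              # interface language
        "tbs": "qdr:m",          # date-range filter (e.g., past month)
        "autocorrect": True,
        "results": 10,
        "page": 1,
    }

    rm = SerperRM(query_params=query_params)
    rm.validate_input()
    data = rm.runner()  # POSTs to https://google.serper.dev/search
    print(data)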
---
 knowledge_storm/rm.py | 85 +++++++++++++++++++++++++++++++++++++++++++--
 1 file changed, 83 insertions(+), 2 deletions(-)

diff --git a/knowledge_storm/rm.py b/knowledge_storm/rm.py
index 86f59703..2c5fdf57 100644
--- a/knowledge_storm/rm.py
+++ b/knowledge_storm/rm.py
@@ -1,6 +1,6 @@
 import logging
 import os
-from typing import Callable, Union, List
+from typing import Callable, Union, List, Dict
 
 import dspy
 import pandas as pd
@@ -10,7 +10,7 @@
 from langchain_qdrant import Qdrant
 from qdrant_client import QdrantClient, models
 from tqdm import tqdm
-
+import requests
 
 from .utils import WebPageHelper
 
@@ -404,3 +404,84 @@ def forward(self, query_or_queries: Union[str, List[str]], exclude_urls: List[st
                 })
 
         return collected_results
+
+
+class SerperRM(dspy.Retrieve):
+    def __init__(self, serper_search_api_key=None, query_params=None, minibatch=None):
+        super().__init__()
+        if not serper_search_api_key and not os.environ.get("SERPER_API_KEY"):
+            raise RuntimeError(
+                "You must supply serper_search_api_key or set environment variable SERPER_API_KEY"
+            )
+        self.serper_search_api_key = serper_search_api_key or os.environ["SERPER_API_KEY"]
+        self.query_params = query_params
+        self.minibatch = minibatch
+        self.base_url = "https://google.serper.dev"
+
+    def validate_input(self):
+        if not isinstance(self.query_params.get("type"), str):
+            raise RuntimeError("A type must be provided.")
+        if not isinstance(self.query_params.get("gl"), str) or len(self.query_params["gl"]) > 2:
+            raise RuntimeError("A valid two-letter country code must be provided.")
+        if not isinstance(self.query_params.get("hl"), str):
+            raise RuntimeError("A language code must be provided.")
+        if not isinstance(self.query_params.get("tbs"), str):
+            raise RuntimeError("A date range must be provided.")
+        if not isinstance(self.query_params.get("autocorrect"), bool):
+            raise RuntimeError("An autocorrect boolean must be provided.")
+        if not isinstance(self.query_params.get("results"), int):
+            raise RuntimeError("The number of results to return must be provided.")
+        if not isinstance(self.query_params.get("page"), int):
+            raise RuntimeError("The number of pages to return must be provided.")
+        if self.minibatch and isinstance(self.query_params, dict):
+            raise RuntimeError(
+                "Minibatch is enabled, but query_params is a dictionary; "
+                "convert it to a list of dictionaries to use minibatching."
+            )
+
+    def runner(self):
+        # Dispatch to the Serper endpoint that matches the requested search type.
+        match self.query_params.get("type"):
+            case "search":
+                return self.run_process("search")
+            case "images":
+                return self.run_process("images")
+            case "videos":
+                return self.run_process("videos")
+            case "places":
+                return self.run_process("places")
+            case "maps":
+                return self.run_process("maps")
+            case "news":
+                return self.run_process("news")
+            case "shopping":
+                return self.run_process("shopping")
+            case "scholar":
+                return self.run_process("scholar")
+            case "patents":
+                return self.run_process("patents")
+            case "autocomplete":
+                return self.run_process("autocomplete")
+            case _:
+                raise RuntimeError(
+                    f"Unsupported search type: {self.query_params.get('type')}"
+                )
+
+    def run_process(self, process_name=None):
+        print(f"Beginning {process_name} process...")
+        self.search_url = f"{self.base_url}/{process_name}"
+        headers = {
+            "X-API-KEY": self.serper_search_api_key,
+            "Content-Type": "application/json",
+        }
+
+        # Serper expects a JSON payload, so send query_params with json= rather than data=.
+        response = requests.request(
+            "POST", self.search_url, headers=headers, json=self.query_params
+        )
+        if not response.ok:
+            raise RuntimeError(
+                f"An error occurred while running the {process_name} process.\n"
+                f"Error is {response.reason}; failed with status code {response.status_code}"
+            )
+        return response.json()

From 0b5a56342a93256155a6f2cdb51ead74be393446 Mon Sep 17 00:00:00 2001
From: zenith110
Date: Sat, 27 Jul 2024 17:09:28 -0400
Subject: [PATCH 2/7] Added serper

---
 .../process_kaggle_arxiv_abstract_dataset.py  |  23 +-
 examples/run_storm_wiki_claude.py             | 145 +++++---
 examples/run_storm_wiki_deepseek.py           | 173 +++++++---
 examples/run_storm_wiki_gpt.py                | 153 +++++---
 examples/run_storm_wiki_gpt_with_VectorRM.py  | 224 ++++++++----
 examples/run_storm_wiki_mistral.py            | 180 ++++++----
 examples/run_storm_wiki_ollama.py             | 189 ++++++----
 examples/run_storm_wiki_serper.py             | 175 ++++++++++
 knowledge_storm/__init__.py                   |   2 +-
 knowledge_storm/interface.py                  |  69 ++--
 knowledge_storm/lm.py                         | 228 +++++++-----
 knowledge_storm/rm.py                         | 326 +++++++++++-------
 knowledge_storm/storm_wiki/engine.py          | 310 +++++++++++------
 .../storm_wiki/modules/article_generation.py  | 117 ++++---
 .../storm_wiki/modules/article_polish.py      |  46 ++-
 .../storm_wiki/modules/knowledge_curation.py  | 249 ++++++++-----
 .../storm_wiki/modules/outline_generation.py  | 106 +++---
 .../storm_wiki/modules/persona_generator.py   |  64 ++--
 .../storm_wiki/modules/retriever.py           |  23 +-
 .../storm_wiki/modules/storm_dataclass.py     | 233 ++++++++----
 knowledge_storm/utils.py                      | 166 +++++----
 requirements.txt                              |   1 +
 22 files changed, 2148 insertions(+), 1054 deletions(-)
 create mode 100644 examples/run_storm_wiki_serper.py

diff --git a/examples/helper/process_kaggle_arxiv_abstract_dataset.py b/examples/helper/process_kaggle_arxiv_abstract_dataset.py
index 4cb885c1..30583c4c 100644
--- a/examples/helper/process_kaggle_arxiv_abstract_dataset.py
+++ b/examples/helper/process_kaggle_arxiv_abstract_dataset.py
@@ -8,21 +8,28 @@
 if __name__ == "__main__":
     parser = ArgumentParser()
-    parser.add_argument("--input-path", type=str, help="Path to arxiv_data_210930-054931.csv.")
-    parser.add_argument("--output-path", type=str,
-                        help="Path to store the csv file that is compatible with VectorRM.")
+    parser.add_argument(
+        "--input-path", type=str, help="Path to arxiv_data_210930-054931.csv."
+    )
+    parser.add_argument(
+        "--output-path",
+        type=str,
+        help="Path to store the csv file that is compatible with VectorRM.",
+    )
     args = parser.parse_args()
 
     df = pd.read_csv(args.input_path)
-    print(f'The original dataset has {len(df)} samples.')
+    print(f"The original dataset has {len(df)} samples.")
 
     # Downsample the dataset.
-    df = df[df['terms'] == "['cs.CV']"]
+    df = df[df["terms"] == "['cs.CV']"]
 
     # Reformat the dataset to match the VectorRM input format.
     df.rename(columns={"abstracts": "content", "titles": "title"}, inplace=True)
-    df['url'] = ['uid_' + str(idx) for idx in range(len(df))]  # Ensure the url is unique.
-    df['description'] = ''
+    df["url"] = [
+        "uid_" + str(idx) for idx in range(len(df))
+    ]  # Ensure the url is unique.
+ df["description"] = "" - print(f'The downsampled dataset has {len(df)} samples.') + print(f"The downsampled dataset has {len(df)} samples.") df.to_csv(args.output_path, index=False) diff --git a/examples/run_storm_wiki_claude.py b/examples/run_storm_wiki_claude.py index 31fef1e1..3d6847cb 100644 --- a/examples/run_storm_wiki_claude.py +++ b/examples/run_storm_wiki_claude.py @@ -19,19 +19,23 @@ import os from argparse import ArgumentParser -from knowledge_storm import STORMWikiRunnerArguments, STORMWikiRunner, STORMWikiLMConfigs +from knowledge_storm import ( + STORMWikiRunnerArguments, + STORMWikiRunner, + STORMWikiLMConfigs, +) from knowledge_storm.lm import ClaudeModel from knowledge_storm.rm import YouRM, BingSearch from knowledge_storm.utils import load_api_key def main(args): - load_api_key(toml_file_path='secrets.toml') + load_api_key(toml_file_path="secrets.toml") lm_configs = STORMWikiLMConfigs() claude_kwargs = { - 'api_key': os.getenv("ANTHROPIC_API_KEY"), - 'temperature': 1.0, - 'top_p': 0.9 + "api_key": os.getenv("ANTHROPIC_API_KEY"), + "temperature": 1.0, + "top_p": 0.9, } # STORM is a LM system so different components can be powered by different models. @@ -39,11 +43,21 @@ def main(args): # which is used to split queries, synthesize answers in the conversation. We recommend using stronger models # for outline_gen_lm which is responsible for organizing the collected information, and article_gen_lm # which is responsible for generating sections with citations. - conv_simulator_lm = ClaudeModel(model='claude-3-haiku-20240307', max_tokens=500, **claude_kwargs) - question_asker_lm = ClaudeModel(model='claude-3-sonnet-20240229', max_tokens=500, **claude_kwargs) - outline_gen_lm = ClaudeModel(model='claude-3-opus-20240229', max_tokens=400, **claude_kwargs) - article_gen_lm = ClaudeModel(model='claude-3-opus-20240229', max_tokens=700, **claude_kwargs) - article_polish_lm = ClaudeModel(model='claude-3-opus-20240229', max_tokens=4000, **claude_kwargs) + conv_simulator_lm = ClaudeModel( + model="claude-3-haiku-20240307", max_tokens=500, **claude_kwargs + ) + question_asker_lm = ClaudeModel( + model="claude-3-sonnet-20240229", max_tokens=500, **claude_kwargs + ) + outline_gen_lm = ClaudeModel( + model="claude-3-opus-20240229", max_tokens=400, **claude_kwargs + ) + article_gen_lm = ClaudeModel( + model="claude-3-opus-20240229", max_tokens=700, **claude_kwargs + ) + article_polish_lm = ClaudeModel( + model="claude-3-opus-20240229", max_tokens=4000, **claude_kwargs + ) lm_configs.set_conv_simulator_lm(conv_simulator_lm) lm_configs.set_question_asker_lm(question_asker_lm) @@ -61,14 +75,16 @@ def main(args): # STORM is a knowledge curation system which consumes information from the retrieval module. # Currently, the information source is the Internet and we use search engine API as the retrieval module. 
- if args.retriever == 'bing': - rm = BingSearch(bing_search_api=os.getenv('BING_SEARCH_API_KEY'), k=engine_args.search_top_k) - elif args.retriever == 'you': - rm = YouRM(ydc_api_key=os.getenv('YDC_API_KEY'), k=engine_args.search_top_k) + if args.retriever == "bing": + rm = BingSearch( + bing_search_api=os.getenv("BING_SEARCH_API_KEY"), k=engine_args.search_top_k + ) + elif args.retriever == "you": + rm = YouRM(ydc_api_key=os.getenv("YDC_API_KEY"), k=engine_args.search_top_k) runner = STORMWikiRunner(engine_args, lm_configs, rm) - topic = input('Topic: ') + topic = input("Topic: ") runner.run( topic=topic, do_research=args.do_research, @@ -80,38 +96,81 @@ def main(args): runner.summary() -if __name__ == '__main__': +if __name__ == "__main__": parser = ArgumentParser() # global arguments - parser.add_argument('--output-dir', type=str, default='./results/claude', - help='Directory to store the outputs.') - parser.add_argument('--max-thread-num', type=int, default=3, - help='Maximum number of threads to use. The information seeking part and the article generation' - 'part can speed up by using multiple threads. Consider reducing it if keep getting ' - '"Exceed rate limit" error when calling LM API.') - parser.add_argument('--retriever', type=str, choices=['bing', 'you'], - help='The search engine API to use for retrieving information.') + parser.add_argument( + "--output-dir", + type=str, + default="./results/claude", + help="Directory to store the outputs.", + ) + parser.add_argument( + "--max-thread-num", + type=int, + default=3, + help="Maximum number of threads to use. The information seeking part and the article generation" + "part can speed up by using multiple threads. Consider reducing it if keep getting " + '"Exceed rate limit" error when calling LM API.', + ) + parser.add_argument( + "--retriever", + type=str, + choices=["bing", "you"], + help="The search engine API to use for retrieving information.", + ) # stage of the pipeline - parser.add_argument('--do-research', action='store_true', - help='If True, simulate conversation to research the topic; otherwise, load the results.') - parser.add_argument('--do-generate-outline', action='store_true', - help='If True, generate an outline for the topic; otherwise, load the results.') - parser.add_argument('--do-generate-article', action='store_true', - help='If True, generate an article for the topic; otherwise, load the results.') - parser.add_argument('--do-polish-article', action='store_true', - help='If True, polish the article by adding a summarization section and (optionally) removing ' - 'duplicate content.') + parser.add_argument( + "--do-research", + action="store_true", + help="If True, simulate conversation to research the topic; otherwise, load the results.", + ) + parser.add_argument( + "--do-generate-outline", + action="store_true", + help="If True, generate an outline for the topic; otherwise, load the results.", + ) + parser.add_argument( + "--do-generate-article", + action="store_true", + help="If True, generate an article for the topic; otherwise, load the results.", + ) + parser.add_argument( + "--do-polish-article", + action="store_true", + help="If True, polish the article by adding a summarization section and (optionally) removing " + "duplicate content.", + ) # hyperparameters for the pre-writing stage - parser.add_argument('--max-conv-turn', type=int, default=3, - help='Maximum number of questions in conversational question asking.') - parser.add_argument('--max-perspective', type=int, default=3, - help='Maximum 
number of perspectives to consider in perspective-guided question asking.') - parser.add_argument('--search-top-k', type=int, default=3, - help='Top k search results to consider for each search query.') + parser.add_argument( + "--max-conv-turn", + type=int, + default=3, + help="Maximum number of questions in conversational question asking.", + ) + parser.add_argument( + "--max-perspective", + type=int, + default=3, + help="Maximum number of perspectives to consider in perspective-guided question asking.", + ) + parser.add_argument( + "--search-top-k", + type=int, + default=3, + help="Top k search results to consider for each search query.", + ) # hyperparameters for the writing stage - parser.add_argument('--retrieve-top-k', type=int, default=3, - help='Top k collected references for each section title.') - parser.add_argument('--remove-duplicate', action='store_true', - help='If True, remove duplicate content from the article.') + parser.add_argument( + "--retrieve-top-k", + type=int, + default=3, + help="Top k collected references for each section title.", + ) + parser.add_argument( + "--remove-duplicate", + action="store_true", + help="If True, remove duplicate content from the article.", + ) main(parser.parse_args()) diff --git a/examples/run_storm_wiki_deepseek.py b/examples/run_storm_wiki_deepseek.py index 2a7b2566..d159e948 100644 --- a/examples/run_storm_wiki_deepseek.py +++ b/examples/run_storm_wiki_deepseek.py @@ -23,7 +23,11 @@ import logging from argparse import ArgumentParser -from knowledge_storm import STORMWikiRunnerArguments, STORMWikiRunner, STORMWikiLMConfigs +from knowledge_storm import ( + STORMWikiRunnerArguments, + STORMWikiRunner, + STORMWikiLMConfigs, +) from knowledge_storm.lm import DeepSeekModel from knowledge_storm.rm import YouRM, BingSearch from knowledge_storm.utils import load_api_key @@ -35,10 +39,10 @@ def sanitize_topic(topic): Remove or replace characters that are not allowed in file names. """ # Replace spaces with underscores - topic = topic.replace(' ', '_') + topic = topic.replace(" ", "_") # Remove any character that isn't alphanumeric, underscore, or hyphen - topic = re.sub(r'[^a-zA-Z0-9_-]', '', topic) + topic = re.sub(r"[^a-zA-Z0-9_-]", "", topic) # Ensure the topic isn't empty after sanitization if not topic: @@ -48,27 +52,35 @@ def sanitize_topic(topic): def main(args): - load_api_key(toml_file_path='secrets.toml') + load_api_key(toml_file_path="secrets.toml") lm_configs = STORMWikiLMConfigs() # Ensure DEEPSEEK_API_KEY is set if not os.getenv("DEEPSEEK_API_KEY"): - raise ValueError("DEEPSEEK_API_KEY environment variable is not set. Please set it in your secrets.toml file.") + raise ValueError( + "DEEPSEEK_API_KEY environment variable is not set. Please set it in your secrets.toml file." 
+ ) deepseek_kwargs = { - 'api_key': os.getenv("DEEPSEEK_API_KEY"), - 'api_base': os.getenv("DEEPSEEK_API_BASE", "https://api.deepseek.com"), - 'temperature': args.temperature, - 'top_p': args.top_p, + "api_key": os.getenv("DEEPSEEK_API_KEY"), + "api_base": os.getenv("DEEPSEEK_API_BASE", "https://api.deepseek.com"), + "temperature": args.temperature, + "top_p": args.top_p, } # DeepSeek offers two main models: 'deepseek-chat' for general tasks and 'deepseek-coder' for coding tasks # Users can choose the appropriate model based on their needs - conv_simulator_lm = DeepSeekModel(model=args.model, max_tokens=500, **deepseek_kwargs) - question_asker_lm = DeepSeekModel(model=args.model, max_tokens=500, **deepseek_kwargs) + conv_simulator_lm = DeepSeekModel( + model=args.model, max_tokens=500, **deepseek_kwargs + ) + question_asker_lm = DeepSeekModel( + model=args.model, max_tokens=500, **deepseek_kwargs + ) outline_gen_lm = DeepSeekModel(model=args.model, max_tokens=400, **deepseek_kwargs) article_gen_lm = DeepSeekModel(model=args.model, max_tokens=700, **deepseek_kwargs) - article_polish_lm = DeepSeekModel(model=args.model, max_tokens=4000, **deepseek_kwargs) + article_polish_lm = DeepSeekModel( + model=args.model, max_tokens=4000, **deepseek_kwargs + ) lm_configs.set_conv_simulator_lm(conv_simulator_lm) lm_configs.set_question_asker_lm(question_asker_lm) @@ -86,16 +98,20 @@ def main(args): # STORM is a knowledge curation system which consumes information from the retrieval module. # Currently, the information source is the Internet and we use search engine API as the retrieval module. - if args.retriever == 'bing': - rm = BingSearch(bing_search_api=os.getenv('BING_SEARCH_API_KEY'), k=engine_args.search_top_k) - elif args.retriever == 'you': - rm = YouRM(ydc_api_key=os.getenv('YDC_API_KEY'), k=engine_args.search_top_k) + if args.retriever == "bing": + rm = BingSearch( + bing_search_api=os.getenv("BING_SEARCH_API_KEY"), k=engine_args.search_top_k + ) + elif args.retriever == "you": + rm = YouRM(ydc_api_key=os.getenv("YDC_API_KEY"), k=engine_args.search_top_k) else: - raise ValueError(f"Invalid retriever: {args.retriever}. Choose either 'bing' or 'you'.") + raise ValueError( + f"Invalid retriever: {args.retriever}. Choose either 'bing' or 'you'." + ) runner = STORMWikiRunner(engine_args, lm_configs, rm) - topic = input('Topic: ') + topic = input("Topic: ") sanitized_topic = sanitize_topic(topic) try: @@ -114,44 +130,95 @@ def main(args): raise -if __name__ == '__main__': +if __name__ == "__main__": parser = ArgumentParser() # global arguments - parser.add_argument('--output-dir', type=str, default='./results/deepseek', - help='Directory to store the outputs.') - parser.add_argument('--max-thread-num', type=int, default=3, - help='Maximum number of threads to use. The information seeking part and the article generation' - 'part can speed up by using multiple threads. Consider reducing it if keep getting ' - '"Exceed rate limit" error when calling LM API.') - parser.add_argument('--retriever', type=str, choices=['bing', 'you'], required=True, - help='The search engine API to use for retrieving information.') - parser.add_argument('--model', type=str, choices=['deepseek-chat', 'deepseek-coder'], default='deepseek-chat', - help='DeepSeek model to use. 
"deepseek-chat" for general tasks, "deepseek-coder" for coding tasks.') - parser.add_argument('--temperature', type=float, default=1.0, - help='Sampling temperature to use.') - parser.add_argument('--top_p', type=float, default=0.9, - help='Top-p sampling parameter.') + parser.add_argument( + "--output-dir", + type=str, + default="./results/deepseek", + help="Directory to store the outputs.", + ) + parser.add_argument( + "--max-thread-num", + type=int, + default=3, + help="Maximum number of threads to use. The information seeking part and the article generation" + "part can speed up by using multiple threads. Consider reducing it if keep getting " + '"Exceed rate limit" error when calling LM API.', + ) + parser.add_argument( + "--retriever", + type=str, + choices=["bing", "you"], + required=True, + help="The search engine API to use for retrieving information.", + ) + parser.add_argument( + "--model", + type=str, + choices=["deepseek-chat", "deepseek-coder"], + default="deepseek-chat", + help='DeepSeek model to use. "deepseek-chat" for general tasks, "deepseek-coder" for coding tasks.', + ) + parser.add_argument( + "--temperature", type=float, default=1.0, help="Sampling temperature to use." + ) + parser.add_argument( + "--top_p", type=float, default=0.9, help="Top-p sampling parameter." + ) # stage of the pipeline - parser.add_argument('--do-research', action='store_true', - help='If True, simulate conversation to research the topic; otherwise, load the results.') - parser.add_argument('--do-generate-outline', action='store_true', - help='If True, generate an outline for the topic; otherwise, load the results.') - parser.add_argument('--do-generate-article', action='store_true', - help='If True, generate an article for the topic; otherwise, load the results.') - parser.add_argument('--do-polish-article', action='store_true', - help='If True, polish the article by adding a summarization section and (optionally) removing ' - 'duplicate content.') + parser.add_argument( + "--do-research", + action="store_true", + help="If True, simulate conversation to research the topic; otherwise, load the results.", + ) + parser.add_argument( + "--do-generate-outline", + action="store_true", + help="If True, generate an outline for the topic; otherwise, load the results.", + ) + parser.add_argument( + "--do-generate-article", + action="store_true", + help="If True, generate an article for the topic; otherwise, load the results.", + ) + parser.add_argument( + "--do-polish-article", + action="store_true", + help="If True, polish the article by adding a summarization section and (optionally) removing " + "duplicate content.", + ) # hyperparameters for the pre-writing stage - parser.add_argument('--max-conv-turn', type=int, default=3, - help='Maximum number of questions in conversational question asking.') - parser.add_argument('--max-perspective', type=int, default=3, - help='Maximum number of perspectives to consider in perspective-guided question asking.') - parser.add_argument('--search-top-k', type=int, default=3, - help='Top k search results to consider for each search query.') + parser.add_argument( + "--max-conv-turn", + type=int, + default=3, + help="Maximum number of questions in conversational question asking.", + ) + parser.add_argument( + "--max-perspective", + type=int, + default=3, + help="Maximum number of perspectives to consider in perspective-guided question asking.", + ) + parser.add_argument( + "--search-top-k", + type=int, + default=3, + help="Top k search results to consider for each 
search query.", + ) # hyperparameters for the writing stage - parser.add_argument('--retrieve-top-k', type=int, default=3, - help='Top k collected references for each section title.') - parser.add_argument('--remove-duplicate', action='store_true', - help='If True, remove duplicate content from the article.') + parser.add_argument( + "--retrieve-top-k", + type=int, + default=3, + help="Top k collected references for each section title.", + ) + parser.add_argument( + "--remove-duplicate", + action="store_true", + help="If True, remove duplicate content from the article.", + ) - main(parser.parse_args()) \ No newline at end of file + main(parser.parse_args()) diff --git a/examples/run_storm_wiki_gpt.py b/examples/run_storm_wiki_gpt.py index b7968152..b97c1c47 100644 --- a/examples/run_storm_wiki_gpt.py +++ b/examples/run_storm_wiki_gpt.py @@ -22,40 +22,54 @@ import os from argparse import ArgumentParser -from knowledge_storm import STORMWikiRunnerArguments, STORMWikiRunner, STORMWikiLMConfigs +from knowledge_storm import ( + STORMWikiRunnerArguments, + STORMWikiRunner, + STORMWikiLMConfigs, +) from knowledge_storm.lm import OpenAIModel, AzureOpenAIModel from knowledge_storm.rm import YouRM, BingSearch from knowledge_storm.utils import load_api_key def main(args): - load_api_key(toml_file_path='secrets.toml') + load_api_key(toml_file_path="secrets.toml") lm_configs = STORMWikiLMConfigs() openai_kwargs = { - 'api_key': os.getenv("OPENAI_API_KEY"), - 'temperature': 1.0, - 'top_p': 0.9, + "api_key": os.getenv("OPENAI_API_KEY"), + "temperature": 1.0, + "top_p": 0.9, } - ModelClass = OpenAIModel if os.getenv('OPENAI_API_TYPE') == 'openai' else AzureOpenAIModel + ModelClass = ( + OpenAIModel if os.getenv("OPENAI_API_TYPE") == "openai" else AzureOpenAIModel + ) # If you are using Azure service, make sure the model name matches your own deployed model name. # The default name here is only used for demonstration and may not match your case. - gpt_35_model_name = 'gpt-3.5-turbo' if os.getenv('OPENAI_API_TYPE') == 'openai' else 'gpt-35-turbo' - gpt_4_model_name = 'gpt-4o' - if os.getenv('OPENAI_API_TYPE') == 'azure': - openai_kwargs['api_base'] = os.getenv('AZURE_API_BASE') - openai_kwargs['api_version'] = os.getenv('AZURE_API_VERSION') + gpt_35_model_name = ( + "gpt-3.5-turbo" if os.getenv("OPENAI_API_TYPE") == "openai" else "gpt-35-turbo" + ) + gpt_4_model_name = "gpt-4o" + if os.getenv("OPENAI_API_TYPE") == "azure": + openai_kwargs["api_base"] = os.getenv("AZURE_API_BASE") + openai_kwargs["api_version"] = os.getenv("AZURE_API_VERSION") # STORM is a LM system so different components can be powered by different models. # For a good balance between cost and quality, you can choose a cheaper/faster model for conv_simulator_lm # which is used to split queries, synthesize answers in the conversation. We recommend using stronger models # for outline_gen_lm which is responsible for organizing the collected information, and article_gen_lm # which is responsible for generating sections with citations. 
- conv_simulator_lm = ModelClass(model=gpt_35_model_name, max_tokens=500, **openai_kwargs) - question_asker_lm = ModelClass(model=gpt_35_model_name, max_tokens=500, **openai_kwargs) + conv_simulator_lm = ModelClass( + model=gpt_35_model_name, max_tokens=500, **openai_kwargs + ) + question_asker_lm = ModelClass( + model=gpt_35_model_name, max_tokens=500, **openai_kwargs + ) outline_gen_lm = ModelClass(model=gpt_4_model_name, max_tokens=400, **openai_kwargs) article_gen_lm = ModelClass(model=gpt_4_model_name, max_tokens=700, **openai_kwargs) - article_polish_lm = ModelClass(model=gpt_4_model_name, max_tokens=4000, **openai_kwargs) + article_polish_lm = ModelClass( + model=gpt_4_model_name, max_tokens=4000, **openai_kwargs + ) lm_configs.set_conv_simulator_lm(conv_simulator_lm) lm_configs.set_question_asker_lm(question_asker_lm) @@ -73,14 +87,16 @@ def main(args): # STORM is a knowledge curation system which consumes information from the retrieval module. # Currently, the information source is the Internet and we use search engine API as the retrieval module. - if args.retriever == 'bing': - rm = BingSearch(bing_search_api=os.getenv('BING_SEARCH_API_KEY'), k=engine_args.search_top_k) - elif args.retriever == 'you': - rm = YouRM(ydc_api_key=os.getenv('YDC_API_KEY'), k=engine_args.search_top_k) + if args.retriever == "bing": + rm = BingSearch( + bing_search_api=os.getenv("BING_SEARCH_API_KEY"), k=engine_args.search_top_k + ) + elif args.retriever == "you": + rm = YouRM(ydc_api_key=os.getenv("YDC_API_KEY"), k=engine_args.search_top_k) runner = STORMWikiRunner(engine_args, lm_configs, rm) - topic = input('Topic: ') + topic = input("Topic: ") runner.run( topic=topic, do_research=args.do_research, @@ -92,38 +108,81 @@ def main(args): runner.summary() -if __name__ == '__main__': +if __name__ == "__main__": parser = ArgumentParser() # global arguments - parser.add_argument('--output-dir', type=str, default='./results/gpt', - help='Directory to store the outputs.') - parser.add_argument('--max-thread-num', type=int, default=3, - help='Maximum number of threads to use. The information seeking part and the article generation' - 'part can speed up by using multiple threads. Consider reducing it if keep getting ' - '"Exceed rate limit" error when calling LM API.') - parser.add_argument('--retriever', type=str, choices=['bing', 'you'], - help='The search engine API to use for retrieving information.') + parser.add_argument( + "--output-dir", + type=str, + default="./results/gpt", + help="Directory to store the outputs.", + ) + parser.add_argument( + "--max-thread-num", + type=int, + default=3, + help="Maximum number of threads to use. The information seeking part and the article generation" + "part can speed up by using multiple threads. 
Consider reducing it if keep getting " + '"Exceed rate limit" error when calling LM API.', + ) + parser.add_argument( + "--retriever", + type=str, + choices=["bing", "you"], + help="The search engine API to use for retrieving information.", + ) # stage of the pipeline - parser.add_argument('--do-research', action='store_true', - help='If True, simulate conversation to research the topic; otherwise, load the results.') - parser.add_argument('--do-generate-outline', action='store_true', - help='If True, generate an outline for the topic; otherwise, load the results.') - parser.add_argument('--do-generate-article', action='store_true', - help='If True, generate an article for the topic; otherwise, load the results.') - parser.add_argument('--do-polish-article', action='store_true', - help='If True, polish the article by adding a summarization section and (optionally) removing ' - 'duplicate content.') + parser.add_argument( + "--do-research", + action="store_true", + help="If True, simulate conversation to research the topic; otherwise, load the results.", + ) + parser.add_argument( + "--do-generate-outline", + action="store_true", + help="If True, generate an outline for the topic; otherwise, load the results.", + ) + parser.add_argument( + "--do-generate-article", + action="store_true", + help="If True, generate an article for the topic; otherwise, load the results.", + ) + parser.add_argument( + "--do-polish-article", + action="store_true", + help="If True, polish the article by adding a summarization section and (optionally) removing " + "duplicate content.", + ) # hyperparameters for the pre-writing stage - parser.add_argument('--max-conv-turn', type=int, default=3, - help='Maximum number of questions in conversational question asking.') - parser.add_argument('--max-perspective', type=int, default=3, - help='Maximum number of perspectives to consider in perspective-guided question asking.') - parser.add_argument('--search-top-k', type=int, default=3, - help='Top k search results to consider for each search query.') + parser.add_argument( + "--max-conv-turn", + type=int, + default=3, + help="Maximum number of questions in conversational question asking.", + ) + parser.add_argument( + "--max-perspective", + type=int, + default=3, + help="Maximum number of perspectives to consider in perspective-guided question asking.", + ) + parser.add_argument( + "--search-top-k", + type=int, + default=3, + help="Top k search results to consider for each search query.", + ) # hyperparameters for the writing stage - parser.add_argument('--retrieve-top-k', type=int, default=3, - help='Top k collected references for each section title.') - parser.add_argument('--remove-duplicate', action='store_true', - help='If True, remove duplicate content from the article.') + parser.add_argument( + "--retrieve-top-k", + type=int, + default=3, + help="Top k collected references for each section title.", + ) + parser.add_argument( + "--remove-duplicate", + action="store_true", + help="If True, remove duplicate content from the article.", + ) main(parser.parse_args()) diff --git a/examples/run_storm_wiki_gpt_with_VectorRM.py b/examples/run_storm_wiki_gpt_with_VectorRM.py index 2c07ffc2..6eed8444 100644 --- a/examples/run_storm_wiki_gpt_with_VectorRM.py +++ b/examples/run_storm_wiki_gpt_with_VectorRM.py @@ -30,7 +30,11 @@ import sys from argparse import ArgumentParser -from knowledge_storm import STORMWikiRunnerArguments, STORMWikiRunner, STORMWikiLMConfigs +from knowledge_storm import ( + STORMWikiRunnerArguments, + 
STORMWikiRunner, + STORMWikiLMConfigs, +) from knowledge_storm.rm import VectorRM from knowledge_storm.lm import OpenAIModel, AzureOpenAIModel from knowledge_storm.utils import load_api_key @@ -38,35 +42,45 @@ def main(args): # Load API key from the specified toml file path - load_api_key(toml_file_path='secrets.toml') + load_api_key(toml_file_path="secrets.toml") # Initialize the language model configurations engine_lm_configs = STORMWikiLMConfigs() openai_kwargs = { - 'api_key': os.getenv("OPENAI_API_KEY"), - 'temperature': 1.0, - 'top_p': 0.9, + "api_key": os.getenv("OPENAI_API_KEY"), + "temperature": 1.0, + "top_p": 0.9, } - ModelClass = OpenAIModel if os.getenv('OPENAI_API_TYPE') == 'openai' else AzureOpenAIModel + ModelClass = ( + OpenAIModel if os.getenv("OPENAI_API_TYPE") == "openai" else AzureOpenAIModel + ) # If you are using Azure service, make sure the model name matches your own deployed model name. # The default name here is only used for demonstration and may not match your case. - gpt_35_model_name = 'gpt-3.5-turbo' if os.getenv('OPENAI_API_TYPE') == 'openai' else 'gpt-35-turbo' - gpt_4_model_name = 'gpt-4o' - if os.getenv('OPENAI_API_TYPE') == 'azure': - openai_kwargs['api_base'] = os.getenv('AZURE_API_BASE') - openai_kwargs['api_version'] = os.getenv('AZURE_API_VERSION') + gpt_35_model_name = ( + "gpt-3.5-turbo" if os.getenv("OPENAI_API_TYPE") == "openai" else "gpt-35-turbo" + ) + gpt_4_model_name = "gpt-4o" + if os.getenv("OPENAI_API_TYPE") == "azure": + openai_kwargs["api_base"] = os.getenv("AZURE_API_BASE") + openai_kwargs["api_version"] = os.getenv("AZURE_API_VERSION") # STORM is a LM system so different components can be powered by different models. - # For a good balance between cost and quality, you can choose a cheaper/faster model for conv_simulator_lm + # For a good balance between cost and quality, you can choose a cheaper/faster model for conv_simulator_lm # which is used to split queries, synthesize answers in the conversation. We recommend using stronger models # for outline_gen_lm which is responsible for organizing the collected information, and article_gen_lm # which is responsible for generating sections with citations. 
- conv_simulator_lm = ModelClass(model=gpt_35_model_name, max_tokens=500, **openai_kwargs) - question_asker_lm = ModelClass(model=gpt_35_model_name, max_tokens=500, **openai_kwargs) + conv_simulator_lm = ModelClass( + model=gpt_35_model_name, max_tokens=500, **openai_kwargs + ) + question_asker_lm = ModelClass( + model=gpt_35_model_name, max_tokens=500, **openai_kwargs + ) outline_gen_lm = ModelClass(model=gpt_4_model_name, max_tokens=400, **openai_kwargs) article_gen_lm = ModelClass(model=gpt_4_model_name, max_tokens=700, **openai_kwargs) - article_polish_lm = ModelClass(model=gpt_4_model_name, max_tokens=4000, **openai_kwargs) + article_polish_lm = ModelClass( + model=gpt_4_model_name, max_tokens=4000, **openai_kwargs + ) engine_lm_configs.set_conv_simulator_lm(conv_simulator_lm) engine_lm_configs.set_question_asker_lm(question_asker_lm) @@ -84,30 +98,36 @@ def main(args): ) # Setup VectorRM to retrieve information from your own data - rm = VectorRM(collection_name=args.collection_name, device=args.device, k=engine_args.search_top_k) + rm = VectorRM( + collection_name=args.collection_name, + device=args.device, + k=engine_args.search_top_k, + ) # initialize the vector store, either online (store the db on Qdrant server) or offline (store the db locally): - if args.vector_db_mode == 'offline': + if args.vector_db_mode == "offline": rm.init_offline_vector_db(vector_store_path=args.offline_vector_db_dir) - elif args.vector_db_mode == 'online': - rm.init_online_vector_db(url=args.online_vector_db_url, api_key=os.getenv('QDRANT_API_KEY')) + elif args.vector_db_mode == "online": + rm.init_online_vector_db( + url=args.online_vector_db_url, api_key=os.getenv("QDRANT_API_KEY") + ) # Update the vector store with the documents in the csv file if args.update_vector_store: rm.update_vector_store( file_path=args.csv_file_path, - content_column='content', - title_column='title', - url_column='url', - desc_column='description', - batch_size=args.embed_batch_size + content_column="content", + title_column="title", + url_column="url", + desc_column="description", + batch_size=args.embed_batch_size, ) # Initialize the STORM Wiki Runner runner = STORMWikiRunner(engine_args, engine_lm_configs, rm) # run the pipeline - topic = input('Topic: ') + topic = input("Topic: ") runner.run( topic=topic, do_research=args.do_research, @@ -122,51 +142,119 @@ def main(args): if __name__ == "__main__": parser = ArgumentParser() # global arguments - parser.add_argument('--output-dir', type=str, default='./results/gpt_retrieval', - help='Directory to store the outputs.') - parser.add_argument('--max-thread-num', type=int, default=3, - help='Maximum number of threads to use. The information seeking part and the article generation' - 'part can speed up by using multiple threads. Consider reducing it if keep getting ' - '"Exceed rate limit" error when calling LM API.') + parser.add_argument( + "--output-dir", + type=str, + default="./results/gpt_retrieval", + help="Directory to store the outputs.", + ) + parser.add_argument( + "--max-thread-num", + type=int, + default=3, + help="Maximum number of threads to use. The information seeking part and the article generation" + "part can speed up by using multiple threads. 
Consider reducing it if keep getting " + '"Exceed rate limit" error when calling LM API.', + ) # provide local corpus and set up vector db - parser.add_argument('--collection-name', type=str, default="my_documents", - help='The collection name for vector store.') - parser.add_argument('--device', type=str, default="mps", - help='The device used to run the retrieval model (mps, cuda, cpu, etc).') - parser.add_argument('--vector-db-mode', type=str, choices=['offline', 'online'], - help='The mode of the Qdrant vector store (offline or online).') - parser.add_argument('--offline-vector-db-dir', type=str, default='./vector_store', - help='If use offline mode, please provide the directory to store the vector store.') - parser.add_argument('--online-vector-db-url', type=str, - help='If use online mode, please provide the url of the Qdrant server.') - parser.add_argument('--update-vector-store', action='store_true', - help='If True, update the vector store with the documents in the csv file; otherwise, ' - 'use the existing vector store.') - parser.add_argument('--csv-file-path', type=str, - help='The path of the custom document corpus in CSV format. The CSV file should include ' - 'content, title, url, and description columns.') - parser.add_argument('--embed-batch-size', type=int, default=64, - help='Batch size for embedding the documents in the csv file.') + parser.add_argument( + "--collection-name", + type=str, + default="my_documents", + help="The collection name for vector store.", + ) + parser.add_argument( + "--device", + type=str, + default="mps", + help="The device used to run the retrieval model (mps, cuda, cpu, etc).", + ) + parser.add_argument( + "--vector-db-mode", + type=str, + choices=["offline", "online"], + help="The mode of the Qdrant vector store (offline or online).", + ) + parser.add_argument( + "--offline-vector-db-dir", + type=str, + default="./vector_store", + help="If use offline mode, please provide the directory to store the vector store.", + ) + parser.add_argument( + "--online-vector-db-url", + type=str, + help="If use online mode, please provide the url of the Qdrant server.", + ) + parser.add_argument( + "--update-vector-store", + action="store_true", + help="If True, update the vector store with the documents in the csv file; otherwise, " + "use the existing vector store.", + ) + parser.add_argument( + "--csv-file-path", + type=str, + help="The path of the custom document corpus in CSV format. 
The CSV file should include " + "content, title, url, and description columns.", + ) + parser.add_argument( + "--embed-batch-size", + type=int, + default=64, + help="Batch size for embedding the documents in the csv file.", + ) # stage of the pipeline - parser.add_argument('--do-research', action='store_true', - help='If True, simulate conversation to research the topic; otherwise, load the results.') - parser.add_argument('--do-generate-outline', action='store_true', - help='If True, generate an outline for the topic; otherwise, load the results.') - parser.add_argument('--do-generate-article', action='store_true', - help='If True, generate an article for the topic; otherwise, load the results.') - parser.add_argument('--do-polish-article', action='store_true', - help='If True, polish the article by adding a summarization section and (optionally) removing ' - 'duplicate content.') + parser.add_argument( + "--do-research", + action="store_true", + help="If True, simulate conversation to research the topic; otherwise, load the results.", + ) + parser.add_argument( + "--do-generate-outline", + action="store_true", + help="If True, generate an outline for the topic; otherwise, load the results.", + ) + parser.add_argument( + "--do-generate-article", + action="store_true", + help="If True, generate an article for the topic; otherwise, load the results.", + ) + parser.add_argument( + "--do-polish-article", + action="store_true", + help="If True, polish the article by adding a summarization section and (optionally) removing " + "duplicate content.", + ) # hyperparameters for the pre-writing stage - parser.add_argument('--max-conv-turn', type=int, default=3, - help='Maximum number of questions in conversational question asking.') - parser.add_argument('--max-perspective', type=int, default=3, - help='Maximum number of perspectives to consider in perspective-guided question asking.') - parser.add_argument('--search-top-k', type=int, default=3, - help='Top k search results to consider for each search query.') + parser.add_argument( + "--max-conv-turn", + type=int, + default=3, + help="Maximum number of questions in conversational question asking.", + ) + parser.add_argument( + "--max-perspective", + type=int, + default=3, + help="Maximum number of perspectives to consider in perspective-guided question asking.", + ) + parser.add_argument( + "--search-top-k", + type=int, + default=3, + help="Top k search results to consider for each search query.", + ) # hyperparameters for the writing stage - parser.add_argument('--retrieve-top-k', type=int, default=3, - help='Top k collected references for each section title.') - parser.add_argument('--remove-duplicate', action='store_true', - help='If True, remove duplicate content from the article.') + parser.add_argument( + "--retrieve-top-k", + type=int, + default=3, + help="Top k collected references for each section title.", + ) + parser.add_argument( + "--remove-duplicate", + action="store_true", + help="If True, remove duplicate content from the article.", + ) main(parser.parse_args()) diff --git a/examples/run_storm_wiki_mistral.py b/examples/run_storm_wiki_mistral.py index eb6a4ff6..291d2879 100644 --- a/examples/run_storm_wiki_mistral.py +++ b/examples/run_storm_wiki_mistral.py @@ -15,26 +15,33 @@ storm_gen_article.txt # Final article generated storm_gen_article_polished.txt # Polished final article (if args.do_polish_article is True) """ + import os from argparse import ArgumentParser from dspy import Example -from knowledge_storm import 
STORMWikiRunnerArguments, STORMWikiRunner, STORMWikiLMConfigs +from knowledge_storm import ( + STORMWikiRunnerArguments, + STORMWikiRunner, + STORMWikiLMConfigs, +) from knowledge_storm.lm import VLLMClient from knowledge_storm.rm import YouRM, BingSearch from knowledge_storm.utils import load_api_key def main(args): - load_api_key(toml_file_path='secrets.toml') + load_api_key(toml_file_path="secrets.toml") lm_configs = STORMWikiLMConfigs() mistral_kwargs = { "model": "mistralai/Mistral-7B-Instruct-v0.2", "port": args.port, "url": args.url, - "stop": ('\n\n---',) # dspy uses "\n\n---" to separate examples. Open models sometimes generate this. + "stop": ( + "\n\n---", + ), # dspy uses "\n\n---" to separate examples. Open models sometimes generate this. } conv_simulator_lm = VLLMClient(max_tokens=500, **mistral_kwargs) @@ -59,10 +66,12 @@ def main(args): # STORM is a knowledge curation system which consumes information from the retrieval module. # Currently, the information source is the Internet and we use search engine API as the retrieval module. - if args.retriever == 'bing': - rm = BingSearch(bing_search_api=os.getenv('BING_SEARCH_API_KEY'), k=engine_args.search_top_k) - elif args.retriever == 'you': - rm = YouRM(ydc_api_key=os.getenv('YDC_API_KEY'), k=engine_args.search_top_k) + if args.retriever == "bing": + rm = BingSearch( + bing_search_api=os.getenv("BING_SEARCH_API_KEY"), k=engine_args.search_top_k + ) + elif args.retriever == "you": + rm = YouRM(ydc_api_key=os.getenv("YDC_API_KEY"), k=engine_args.search_top_k) runner = STORMWikiRunner(engine_args, lm_configs, rm) @@ -74,26 +83,28 @@ def main(args): find_related_topic_example = Example( topic="Knowledge Curation", related_topics="https://en.wikipedia.org/wiki/Knowledge_management\n" - "https://en.wikipedia.org/wiki/Information_science\n" - "https://en.wikipedia.org/wiki/Library_science\n" + "https://en.wikipedia.org/wiki/Information_science\n" + "https://en.wikipedia.org/wiki/Library_science\n", ) gen_persona_example = Example( topic="Knowledge Curation", examples="Title: Knowledge management\n" - "Table of Contents: History\nResearch\n Dimensions\n Strategies\n Motivations\nKM technologies" - "\nKnowledge barriers\nKnowledge retention\nKnowledge audit\nKnowledge protection\n" - " Knowledge protection methods\n Formal methods\n Informal methods\n" - " Balancing knowledge protection and knowledge sharing\n Knowledge protection risks", + "Table of Contents: History\nResearch\n Dimensions\n Strategies\n Motivations\nKM technologies" + "\nKnowledge barriers\nKnowledge retention\nKnowledge audit\nKnowledge protection\n" + " Knowledge protection methods\n Formal methods\n Informal methods\n" + " Balancing knowledge protection and knowledge sharing\n Knowledge protection risks", personas="1. Historian of Knowledge Systems: This editor will focus on the history and evolution of knowledge curation. They will provide context on how knowledge curation has changed over time and its impact on modern practices.\n" - "2. Information Science Professional: With insights from 'Information science', this editor will explore the foundational theories, definitions, and philosophy that underpin knowledge curation\n" - "3. Digital Librarian: This editor will delve into the specifics of how digital libraries operate, including software, metadata, digital preservation.\n" - "4. Technical expert: This editor will focus on the technical aspects of knowledge curation, such as common features of content management systems.\n" - "5. 
Museum Curator: The museum curator will contribute expertise on the curation of physical items and the transition of these practices into the digital realm." + "2. Information Science Professional: With insights from 'Information science', this editor will explore the foundational theories, definitions, and philosophy that underpin knowledge curation\n" + "3. Digital Librarian: This editor will delve into the specifics of how digital libraries operate, including software, metadata, digital preservation.\n" + "4. Technical expert: This editor will focus on the technical aspects of knowledge curation, such as common features of content management systems.\n" + "5. Museum Curator: The museum curator will contribute expertise on the curation of physical items and the transition of these practices into the digital realm.", ) runner.storm_knowledge_curation_module.persona_generator.create_writer_with_persona.find_related_topic.demos = [ - find_related_topic_example] + find_related_topic_example + ] runner.storm_knowledge_curation_module.persona_generator.create_writer_with_persona.gen_persona.demos = [ - gen_persona_example] + gen_persona_example + ] # A trade-off of adding one-shot example is that it will increase the input length of the prompt. Also, some # examples may be very long (e.g., an example for writing a section based on the given information), which may @@ -105,24 +116,28 @@ def main(args): topic="Example Topic", conv="Wikipedia Writer: ...\nExpert: ...\nWikipedia Writer: ...\nExpert: ...", old_outline="# Section 1\n## Subsection 1\n## Subsection 2\n" - "# Section 2\n## Subsection 1\n## Subsection 2\n" - "# Section 3", + "# Section 2\n## Subsection 1\n## Subsection 2\n" + "# Section 3", outline="# New Section 1\n## New Subsection 1\n## New Subsection 2\n" - "# New Section 2\n" - "# New Section 3\n## New Subsection 1\n## New Subsection 2\n## New Subsection 3" + "# New Section 2\n" + "# New Section 3\n## New Subsection 1\n## New Subsection 2\n## New Subsection 3", ) - runner.storm_outline_generation_module.write_outline.write_page_outline.demos = [write_page_outline_example] + runner.storm_outline_generation_module.write_outline.write_page_outline.demos = [ + write_page_outline_example + ] write_section_example = Example( info="[1]\nInformation in document 1\n[2]\nInformation in document 2\n[3]\nInformation in document 3", topic="Example Topic", section="Example Section", output="# Example Topic\n## Subsection 1\n" - "This is an example sentence [1]. This is another example sentence [2][3].\n" - "## Subsection 2\nThis is one more example sentence [1]." + "This is an example sentence [1]. 
This is another example sentence [2][3].\n" + "## Subsection 2\nThis is one more example sentence [1].", ) - runner.storm_article_generation.section_gen.write_section.demos = [write_section_example] + runner.storm_article_generation.section_gen.write_section.demos = [ + write_section_example + ] - topic = input('Topic: ') + topic = input("Topic: ") runner.run( topic=topic, do_research=args.do_research, @@ -134,42 +149,87 @@ def main(args): runner.summary() -if __name__ == '__main__': +if __name__ == "__main__": parser = ArgumentParser() # global arguments - parser.add_argument('--url', type=str, default='http://localhost', - help='URL of the VLLM server.') - parser.add_argument('--port', type=int, default=8000, - help='Port of the VLLM server.') - parser.add_argument('--output-dir', type=str, default='./results/mistral_7b', - help='Directory to store the outputs.') - parser.add_argument('--max-thread-num', type=int, default=3, - help='Maximum number of threads to use. The information seeking part and the article generation' - 'part can speed up by using multiple threads. Consider reducing it if keep getting ' - '"Exceed rate limit" error when calling LM API.') - parser.add_argument('--retriever', type=str, choices=['bing', 'you'], - help='The search engine API to use for retrieving information.') + parser.add_argument( + "--url", type=str, default="http://localhost", help="URL of the VLLM server." + ) + parser.add_argument( + "--port", type=int, default=8000, help="Port of the VLLM server." + ) + parser.add_argument( + "--output-dir", + type=str, + default="./results/mistral_7b", + help="Directory to store the outputs.", + ) + parser.add_argument( + "--max-thread-num", + type=int, + default=3, + help="Maximum number of threads to use. The information seeking part and the article generation" + "part can speed up by using multiple threads. 
Consider reducing it if keep getting " + '"Exceed rate limit" error when calling LM API.', + ) + parser.add_argument( + "--retriever", + type=str, + choices=["bing", "you"], + help="The search engine API to use for retrieving information.", + ) # stage of the pipeline - parser.add_argument('--do-research', action='store_true', - help='If True, simulate conversation to research the topic; otherwise, load the results.') - parser.add_argument('--do-generate-outline', action='store_true', - help='If True, generate an outline for the topic; otherwise, load the results.') - parser.add_argument('--do-generate-article', action='store_true', - help='If True, generate an article for the topic; otherwise, load the results.') - parser.add_argument('--do-polish-article', action='store_true', - help='If True, polish the article by adding a summarization section and (optionally) removing ' - 'duplicate content.') + parser.add_argument( + "--do-research", + action="store_true", + help="If True, simulate conversation to research the topic; otherwise, load the results.", + ) + parser.add_argument( + "--do-generate-outline", + action="store_true", + help="If True, generate an outline for the topic; otherwise, load the results.", + ) + parser.add_argument( + "--do-generate-article", + action="store_true", + help="If True, generate an article for the topic; otherwise, load the results.", + ) + parser.add_argument( + "--do-polish-article", + action="store_true", + help="If True, polish the article by adding a summarization section and (optionally) removing " + "duplicate content.", + ) # hyperparameters for the pre-writing stage - parser.add_argument('--max-conv-turn', type=int, default=3, - help='Maximum number of questions in conversational question asking.') - parser.add_argument('--max-perspective', type=int, default=3, - help='Maximum number of perspectives to consider in perspective-guided question asking.') - parser.add_argument('--search-top-k', type=int, default=3, - help='Top k search results to consider for each search query.') + parser.add_argument( + "--max-conv-turn", + type=int, + default=3, + help="Maximum number of questions in conversational question asking.", + ) + parser.add_argument( + "--max-perspective", + type=int, + default=3, + help="Maximum number of perspectives to consider in perspective-guided question asking.", + ) + parser.add_argument( + "--search-top-k", + type=int, + default=3, + help="Top k search results to consider for each search query.", + ) # hyperparameters for the writing stage - parser.add_argument('--retrieve-top-k', type=int, default=3, - help='Top k collected references for each section title.') - parser.add_argument('--remove-duplicate', action='store_true', - help='If True, remove duplicate content from the article.') + parser.add_argument( + "--retrieve-top-k", + type=int, + default=3, + help="Top k collected references for each section title.", + ) + parser.add_argument( + "--remove-duplicate", + action="store_true", + help="If True, remove duplicate content from the article.", + ) main(parser.parse_args()) diff --git a/examples/run_storm_wiki_ollama.py b/examples/run_storm_wiki_ollama.py index 35ba99e1..2e930464 100644 --- a/examples/run_storm_wiki_ollama.py +++ b/examples/run_storm_wiki_ollama.py @@ -15,28 +15,35 @@ storm_gen_article.txt # Final article generated storm_gen_article_polished.txt # Polished final article (if args.do_polish_article is True) """ + import os import sys from argparse import ArgumentParser from dspy import Example 
-sys.path.append('./src') +sys.path.append("./src") from lm import OllamaClient from rm import YouRM, BingSearch -from storm_wiki.engine import STORMWikiRunnerArguments, STORMWikiRunner, STORMWikiLMConfigs +from storm_wiki.engine import ( + STORMWikiRunnerArguments, + STORMWikiRunner, + STORMWikiLMConfigs, +) from utils import load_api_key def main(args): - load_api_key(toml_file_path='secrets.toml') + load_api_key(toml_file_path="secrets.toml") lm_configs = STORMWikiLMConfigs() ollama_kwargs = { "model": args.model, "port": args.port, "url": args.url, - "stop": ('\n\n---',) # dspy uses "\n\n---" to separate examples. Open models sometimes generate this. + "stop": ( + "\n\n---", + ), # dspy uses "\n\n---" to separate examples. Open models sometimes generate this. } conv_simulator_lm = OllamaClient(max_tokens=500, **ollama_kwargs) @@ -61,10 +68,12 @@ def main(args): # STORM is a knowledge curation system which consumes information from the retrieval module. # Currently, the information source is the Internet and we use search engine API as the retrieval module. - if args.retriever == 'bing': - rm = BingSearch(bing_search_api=os.getenv('BING_SEARCH_API_KEY'), k=engine_args.search_top_k) - elif args.retriever == 'you': - rm = YouRM(ydc_api_key=os.getenv('YDC_API_KEY'), k=engine_args.search_top_k) + if args.retriever == "bing": + rm = BingSearch( + bing_search_api=os.getenv("BING_SEARCH_API_KEY"), k=engine_args.search_top_k + ) + elif args.retriever == "you": + rm = YouRM(ydc_api_key=os.getenv("YDC_API_KEY"), k=engine_args.search_top_k) runner = STORMWikiRunner(engine_args, lm_configs, rm) @@ -76,26 +85,28 @@ def main(args): find_related_topic_example = Example( topic="Knowledge Curation", related_topics="https://en.wikipedia.org/wiki/Knowledge_management\n" - "https://en.wikipedia.org/wiki/Information_science\n" - "https://en.wikipedia.org/wiki/Library_science\n" + "https://en.wikipedia.org/wiki/Information_science\n" + "https://en.wikipedia.org/wiki/Library_science\n", ) gen_persona_example = Example( topic="Knowledge Curation", examples="Title: Knowledge management\n" - "Table of Contents: History\nResearch\n Dimensions\n Strategies\n Motivations\nKM technologies" - "\nKnowledge barriers\nKnowledge retention\nKnowledge audit\nKnowledge protection\n" - " Knowledge protection methods\n Formal methods\n Informal methods\n" - " Balancing knowledge protection and knowledge sharing\n Knowledge protection risks", + "Table of Contents: History\nResearch\n Dimensions\n Strategies\n Motivations\nKM technologies" + "\nKnowledge barriers\nKnowledge retention\nKnowledge audit\nKnowledge protection\n" + " Knowledge protection methods\n Formal methods\n Informal methods\n" + " Balancing knowledge protection and knowledge sharing\n Knowledge protection risks", personas="1. Historian of Knowledge Systems: This editor will focus on the history and evolution of knowledge curation. They will provide context on how knowledge curation has changed over time and its impact on modern practices.\n" - "2. Information Science Professional: With insights from 'Information science', this editor will explore the foundational theories, definitions, and philosophy that underpin knowledge curation\n" - "3. Digital Librarian: This editor will delve into the specifics of how digital libraries operate, including software, metadata, digital preservation.\n" - "4. Technical expert: This editor will focus on the technical aspects of knowledge curation, such as common features of content management systems.\n" - "5. 
Museum Curator: The museum curator will contribute expertise on the curation of physical items and the transition of these practices into the digital realm." + "2. Information Science Professional: With insights from 'Information science', this editor will explore the foundational theories, definitions, and philosophy that underpin knowledge curation\n" + "3. Digital Librarian: This editor will delve into the specifics of how digital libraries operate, including software, metadata, digital preservation.\n" + "4. Technical expert: This editor will focus on the technical aspects of knowledge curation, such as common features of content management systems.\n" + "5. Museum Curator: The museum curator will contribute expertise on the curation of physical items and the transition of these practices into the digital realm.", ) runner.storm_knowledge_curation_module.persona_generator.create_writer_with_persona.find_related_topic.demos = [ - find_related_topic_example] + find_related_topic_example + ] runner.storm_knowledge_curation_module.persona_generator.create_writer_with_persona.gen_persona.demos = [ - gen_persona_example] + gen_persona_example + ] # A trade-off of adding one-shot example is that it will increase the input length of the prompt. Also, some # examples may be very long (e.g., an example for writing a section based on the given information), which may @@ -107,24 +118,28 @@ def main(args): topic="Example Topic", conv="Wikipedia Writer: ...\nExpert: ...\nWikipedia Writer: ...\nExpert: ...", old_outline="# Section 1\n## Subsection 1\n## Subsection 2\n" - "# Section 2\n## Subsection 1\n## Subsection 2\n" - "# Section 3", + "# Section 2\n## Subsection 1\n## Subsection 2\n" + "# Section 3", outline="# New Section 1\n## New Subsection 1\n## New Subsection 2\n" - "# New Section 2\n" - "# New Section 3\n## New Subsection 1\n## New Subsection 2\n## New Subsection 3" + "# New Section 2\n" + "# New Section 3\n## New Subsection 1\n## New Subsection 2\n## New Subsection 3", ) - runner.storm_outline_generation_module.write_outline.write_page_outline.demos = [write_page_outline_example] + runner.storm_outline_generation_module.write_outline.write_page_outline.demos = [ + write_page_outline_example + ] write_section_example = Example( info="[1]\nInformation in document 1\n[2]\nInformation in document 2\n[3]\nInformation in document 3", topic="Example Topic", section="Example Section", output="# Example Topic\n## Subsection 1\n" - "This is an example sentence [1]. This is another example sentence [2][3].\n" - "## Subsection 2\nThis is one more example sentence [1]." + "This is an example sentence [1]. 
This is another example sentence [2][3].\n" + "## Subsection 2\nThis is one more example sentence [1].", ) - runner.storm_article_generation.section_gen.write_section.demos = [write_section_example] + runner.storm_article_generation.section_gen.write_section.demos = [ + write_section_example + ] - topic = input('Topic: ') + topic = input("Topic: ") runner.run( topic=topic, do_research=args.do_research, @@ -136,44 +151,90 @@ def main(args): runner.summary() -if __name__ == '__main__': +if __name__ == "__main__": parser = ArgumentParser() # global arguments - parser.add_argument('--url', type=str, default='http://localhost', - help='URL of the Ollama server.') - parser.add_argument('--port', type=int, default=11434, - help='Port of the Ollama server.') - parser.add_argument('--model', type=str, default='llama3:latest', - help='Model of the Ollama server.') - parser.add_argument('--output-dir', type=str, default='./results/ollama', - help='Directory to store the outputs.') - parser.add_argument('--max-thread-num', type=int, default=3, - help='Maximum number of threads to use. The information seeking part and the article generation' - 'part can speed up by using multiple threads. Consider reducing it if keep getting ' - '"Exceed rate limit" error when calling LM API.') - parser.add_argument('--retriever', type=str, choices=['bing', 'you'], - help='The search engine API to use for retrieving information.') + parser.add_argument( + "--url", type=str, default="http://localhost", help="URL of the Ollama server." + ) + parser.add_argument( + "--port", type=int, default=11434, help="Port of the Ollama server." + ) + parser.add_argument( + "--model", type=str, default="llama3:latest", help="Model of the Ollama server." + ) + parser.add_argument( + "--output-dir", + type=str, + default="./results/ollama", + help="Directory to store the outputs.", + ) + parser.add_argument( + "--max-thread-num", + type=int, + default=3, + help="Maximum number of threads to use. The information seeking part and the article generation" + "part can speed up by using multiple threads. 
Consider reducing it if keep getting " + '"Exceed rate limit" error when calling LM API.', + ) + parser.add_argument( + "--retriever", + type=str, + choices=["bing", "you"], + help="The search engine API to use for retrieving information.", + ) # stage of the pipeline - parser.add_argument('--do-research', action='store_true', - help='If True, simulate conversation to research the topic; otherwise, load the results.') - parser.add_argument('--do-generate-outline', action='store_true', - help='If True, generate an outline for the topic; otherwise, load the results.') - parser.add_argument('--do-generate-article', action='store_true', - help='If True, generate an article for the topic; otherwise, load the results.') - parser.add_argument('--do-polish-article', action='store_true', - help='If True, polish the article by adding a summarization section and (optionally) removing ' - 'duplicate content.') + parser.add_argument( + "--do-research", + action="store_true", + help="If True, simulate conversation to research the topic; otherwise, load the results.", + ) + parser.add_argument( + "--do-generate-outline", + action="store_true", + help="If True, generate an outline for the topic; otherwise, load the results.", + ) + parser.add_argument( + "--do-generate-article", + action="store_true", + help="If True, generate an article for the topic; otherwise, load the results.", + ) + parser.add_argument( + "--do-polish-article", + action="store_true", + help="If True, polish the article by adding a summarization section and (optionally) removing " + "duplicate content.", + ) # hyperparameters for the pre-writing stage - parser.add_argument('--max-conv-turn', type=int, default=3, - help='Maximum number of questions in conversational question asking.') - parser.add_argument('--max-perspective', type=int, default=3, - help='Maximum number of perspectives to consider in perspective-guided question asking.') - parser.add_argument('--search-top-k', type=int, default=3, - help='Top k search results to consider for each search query.') + parser.add_argument( + "--max-conv-turn", + type=int, + default=3, + help="Maximum number of questions in conversational question asking.", + ) + parser.add_argument( + "--max-perspective", + type=int, + default=3, + help="Maximum number of perspectives to consider in perspective-guided question asking.", + ) + parser.add_argument( + "--search-top-k", + type=int, + default=3, + help="Top k search results to consider for each search query.", + ) # hyperparameters for the writing stage - parser.add_argument('--retrieve-top-k', type=int, default=3, - help='Top k collected references for each section title.') - parser.add_argument('--remove-duplicate', action='store_true', - help='If True, remove duplicate content from the article.') + parser.add_argument( + "--retrieve-top-k", + type=int, + default=3, + help="Top k collected references for each section title.", + ) + parser.add_argument( + "--remove-duplicate", + action="store_true", + help="If True, remove duplicate content from the article.", + ) - main(parser.parse_args()) \ No newline at end of file + main(parser.parse_args()) diff --git a/examples/run_storm_wiki_serper.py b/examples/run_storm_wiki_serper.py new file mode 100644 index 00000000..70c2ac1d --- /dev/null +++ b/examples/run_storm_wiki_serper.py @@ -0,0 +1,175 @@ +""" +STORM Wiki pipeline powered by Claude family models and serper search engine. 
+
+You need to set up the following environment variables to run this script:
+    - ANTHROPIC_API_KEY: Anthropic API key
+    - SERPER_API_KEY: Serper.dev API key
+
+Output will be structured as below
+args.output_dir/
+    topic_name/  # topic_name will follow convention of underscore-connected topic name w/o space and slash
+        conversation_log.json           # Log of information-seeking conversation
+        raw_search_results.json         # Raw search results from search engine
+        direct_gen_outline.txt          # Outline directly generated with LLM's parametric knowledge
+        storm_gen_outline.txt           # Outline refined with collected information
+        url_to_info.json                # Sources that are used in the final article
+        storm_gen_article.txt           # Final article generated
+        storm_gen_article_polished.txt  # Polished final article (if args.do_polish_article is True)
+"""
+
+import os
+from argparse import ArgumentParser
+
+from knowledge_storm import (
+    STORMWikiRunnerArguments,
+    STORMWikiRunner,
+    STORMWikiLMConfigs,
+)
+from knowledge_storm.lm import ClaudeModel
+from knowledge_storm.rm import SerperRM
+from knowledge_storm.utils import load_api_key
+
+
+def main(args):
+    load_api_key(toml_file_path="secrets.toml")
+    lm_configs = STORMWikiLMConfigs()
+    claude_kwargs = {
+        "api_key": os.getenv("ANTHROPIC_API_KEY"),
+        "temperature": 1.0,
+        "top_p": 0.9,
+    }
+
+    # STORM is an LM system, so different components can be powered by different models.
+    # For a good balance between cost and quality, you can choose a cheaper/faster model for conv_simulator_lm,
+    # which is used to split queries and synthesize answers in the conversation. We recommend using stronger models
+    # for outline_gen_lm, which is responsible for organizing the collected information, and article_gen_lm,
+    # which is responsible for generating sections with citations.
+    conv_simulator_lm = ClaudeModel(
+        model="claude-3-haiku-20240307", max_tokens=500, **claude_kwargs
+    )
+    question_asker_lm = ClaudeModel(
+        model="claude-3-sonnet-20240229", max_tokens=500, **claude_kwargs
+    )
+    outline_gen_lm = ClaudeModel(
+        model="claude-3-opus-20240229", max_tokens=400, **claude_kwargs
+    )
+    article_gen_lm = ClaudeModel(
+        model="claude-3-opus-20240229", max_tokens=700, **claude_kwargs
+    )
+    article_polish_lm = ClaudeModel(
+        model="claude-3-opus-20240229", max_tokens=4000, **claude_kwargs
+    )
+
+    lm_configs.set_conv_simulator_lm(conv_simulator_lm)
+    lm_configs.set_question_asker_lm(question_asker_lm)
+    lm_configs.set_outline_gen_lm(outline_gen_lm)
+    lm_configs.set_article_gen_lm(article_gen_lm)
+    lm_configs.set_article_polish_lm(article_polish_lm)
+
+    engine_args = STORMWikiRunnerArguments(
+        output_dir=args.output_dir,
+        max_conv_turn=args.max_conv_turn,
+        max_perspective=args.max_perspective,
+        search_top_k=args.search_top_k,
+        max_thread_num=args.max_thread_num,
+    )
+    # Documentation for the query parameters is available here:
+    # https://serper.dev/playground
+    # Note that tbs (the date-range filter) only accepts a fixed set of hardcoded values.
+    # num is the number of results per page; increments of 10 (10, 20, etc.) are recommended.
+    # page is the number of result pages to fetch.
+    # gl is the country the Google search will originate from.
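+    # For illustration, a fuller query_params dict combining these fields might look like
+    # the following (example values only; see the playground documentation above for the
+    # full parameter list):
+    #   {"autocorrect": True, "num": 10, "page": 1, "gl": "us", "hl": "en", "tbs": "qdr:m"}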
+ topic = input("topic: ") + data = {"autocorrect": True, "num": 10, "page": 1} + rm = SerperRM(serper_search_api_key=os.getenv("SERPER_API_KEY"), query_params=data) + + runner = STORMWikiRunner(engine_args, lm_configs, rm) + + runner.run( + topic=topic, + do_research=args.do_research, + do_generate_outline=args.do_generate_outline, + do_generate_article=args.do_generate_article, + do_polish_article=args.do_polish_article, + ) + runner.post_run() + runner.summary() + + +if __name__ == "__main__": + parser = ArgumentParser() + # global arguments + parser.add_argument( + "--output-dir", + type=str, + default="./results/claude", + help="Directory to store the outputs.", + ) + parser.add_argument( + "--max-thread-num", + type=int, + default=3, + help="Maximum number of threads to use. The information seeking part and the article generation" + "part can speed up by using multiple threads. Consider reducing it if keep getting " + '"Exceed rate limit" error when calling LM API.', + ) + parser.add_argument( + "--retriever", + type=str, + choices=["bing", "you", "serper"], + help="The search engine API to use for retrieving information.", + ) + # stage of the pipeline + parser.add_argument( + "--do-research", + action="store_true", + help="If True, simulate conversation to research the topic; otherwise, load the results.", + ) + parser.add_argument( + "--do-generate-outline", + action="store_true", + help="If True, generate an outline for the topic; otherwise, load the results.", + ) + parser.add_argument( + "--do-generate-article", + action="store_true", + help="If True, generate an article for the topic; otherwise, load the results.", + ) + parser.add_argument( + "--do-polish-article", + action="store_true", + help="If True, polish the article by adding a summarization section and (optionally) removing " + "duplicate content.", + ) + # hyperparameters for the pre-writing stage + parser.add_argument( + "--max-conv-turn", + type=int, + default=3, + help="Maximum number of questions in conversational question asking.", + ) + parser.add_argument( + "--max-perspective", + type=int, + default=3, + help="Maximum number of perspectives to consider in perspective-guided question asking.", + ) + parser.add_argument( + "--search-top-k", + type=int, + default=3, + help="Top k search results to consider for each search query.", + ) + # hyperparameters for the writing stage + parser.add_argument( + "--retrieve-top-k", + type=int, + default=3, + help="Top k collected references for each section title.", + ) + parser.add_argument( + "--remove-duplicate", + action="store_true", + help="If True, remove duplicate content from the article.", + ) + + main(parser.parse_args()) diff --git a/knowledge_storm/__init__.py b/knowledge_storm/__init__.py index f1fd18ea..74dcabbe 100644 --- a/knowledge_storm/__init__.py +++ b/knowledge_storm/__init__.py @@ -1,5 +1,5 @@ from .storm_wiki.engine import ( STORMWikiLMConfigs, STORMWikiRunnerArguments, - STORMWikiRunner + STORMWikiRunner, ) diff --git a/knowledge_storm/interface.py b/knowledge_storm/interface.py index 03df2fb6..f6c11bd9 100644 --- a/knowledge_storm/interface.py +++ b/knowledge_storm/interface.py @@ -5,7 +5,9 @@ from collections import OrderedDict from typing import Dict, List, Optional, Union -logging.basicConfig(level=logging.INFO, format='%(name)s : %(levelname)-8s : %(message)s') +logging.basicConfig( + level=logging.INFO, format="%(name)s : %(levelname)-8s : %(message)s" +) logger = logging.getLogger(__name__) @@ -70,7 +72,9 @@ class Article(ABC): def 
__init__(self, topic_name): self.root = ArticleSectionNode(topic_name) - def find_section(self, node: ArticleSectionNode, name: str) -> Optional[ArticleSectionNode]: + def find_section( + self, node: ArticleSectionNode, name: str + ) -> Optional[ArticleSectionNode]: """ Return the node of the section given the section name. @@ -152,7 +156,9 @@ def prune_empty_nodes(self, node=None): if node is None: node = self.root - node.children[:] = [child for child in node.children if self.prune_empty_nodes(child)] + node.children[:] = [ + child for child in node.children if self.prune_empty_nodes(child) + ] if (node.content is None or node.content == "") and not node.children: return None @@ -178,7 +184,9 @@ def update_search_top_k(self, k): def collect_and_reset_rm_usage(self): combined_usage = [] for attr_name in self.__dict__: - if '_rm' in attr_name and hasattr(getattr(self, attr_name), 'get_usage_and_reset'): + if "_rm" in attr_name and hasattr( + getattr(self, attr_name), "get_usage_and_reset" + ): combined_usage.append(getattr(self, attr_name).get_usage_and_reset()) name_to_usage = {} @@ -240,7 +248,9 @@ class OutlineGenerationModule(ABC): """ @abstractmethod - def generate_outline(self, topic: str, information_table: InformationTable, **kwargs) -> Article: + def generate_outline( + self, topic: str, information_table: InformationTable, **kwargs + ) -> Article: """ Generate outline for the article. Required arguments include: topic: the topic of interest @@ -263,11 +273,13 @@ class ArticleGenerationModule(ABC): """ @abstractmethod - def generate_article(self, - topic: str, - information_table: InformationTable, - article_with_outline: Article, - **kwargs) -> Article: + def generate_article( + self, + topic: str, + information_table: InformationTable, + article_with_outline: Article, + **kwargs, + ) -> Article: """ Generate article. Required arguments include: topic: the topic of interest @@ -312,14 +324,15 @@ def wrapper(self, *args, **kwargs): class LMConfigs(ABC): """Abstract base class for language model configurations of the knowledge curation engine. - The language model used for each part should be declared with a suffix '_lm' in the attribute name.""" + The language model used for each part should be declared with a suffix '_lm' in the attribute name. + """ def __init__(self): pass def init_check(self): for attr_name in self.__dict__: - if '_lm' in attr_name and getattr(self, attr_name) is None: + if "_lm" in attr_name and getattr(self, attr_name) is None: logging.warning( f"Language model for {attr_name} is not initialized. 
Please call set_{attr_name}()" ) @@ -327,7 +340,7 @@ def init_check(self): def collect_and_reset_lm_history(self): history = [] for attr_name in self.__dict__: - if '_lm' in attr_name and hasattr(getattr(self, attr_name), 'history'): + if "_lm" in attr_name and hasattr(getattr(self, attr_name), "history"): history.extend(getattr(self, attr_name).history) getattr(self, attr_name).history = [] @@ -336,7 +349,9 @@ def collect_and_reset_lm_history(self): def collect_and_reset_lm_usage(self): combined_usage = [] for attr_name in self.__dict__: - if '_lm' in attr_name and hasattr(getattr(self, attr_name), 'get_usage_and_reset'): + if "_lm" in attr_name and hasattr( + getattr(self, attr_name), "get_usage_and_reset" + ): combined_usage.append(getattr(self, attr_name).get_usage_and_reset()) model_name_to_usage = {} @@ -345,8 +360,12 @@ def collect_and_reset_lm_usage(self): if model_name not in model_name_to_usage: model_name_to_usage[model_name] = tokens else: - model_name_to_usage[model_name]['prompt_tokens'] += tokens['prompt_tokens'] - model_name_to_usage[model_name]['completion_tokens'] += tokens['completion_tokens'] + model_name_to_usage[model_name]["prompt_tokens"] += tokens[ + "prompt_tokens" + ] + model_name_to_usage[model_name]["completion_tokens"] += tokens[ + "completion_tokens" + ] return model_name_to_usage @@ -354,8 +373,9 @@ def log(self): return OrderedDict( { - attr_name: getattr(self, attr_name).kwargs for attr_name in self.__dict__ if - '_lm' in attr_name and hasattr(getattr(self, attr_name), 'kwargs') + attr_name: getattr(self, attr_name).kwargs + for attr_name in self.__dict__ + if "_lm" in attr_name and hasattr(getattr(self, attr_name), "kwargs") } ) @@ -379,16 +399,21 @@ def wrapper(*args, **kwargs): self.time[func.__name__] = execution_time logger.info(f"{func.__name__} executed in {execution_time:.4f} seconds") self.lm_cost[func.__name__] = self.lm_configs.collect_and_reset_lm_usage() - if hasattr(self, 'retriever'): - self.rm_cost[func.__name__] = self.retriever.collect_and_reset_rm_usage() + if hasattr(self, "retriever"): + self.rm_cost[func.__name__] = ( + self.retriever.collect_and_reset_rm_usage() + ) return result return wrapper def apply_decorators(self): """Apply decorators to methods that need them.""" - methods_to_decorate = [method_name for method_name in dir(self) - if callable(getattr(self, method_name)) and method_name.startswith('run_')] + methods_to_decorate = [ + method_name + for method_name in dir(self) + if callable(getattr(self, method_name)) and method_name.startswith("run_") + ] for method_name in methods_to_decorate: original_method = getattr(self, method_name) decorated_method = self.log_execution_time_and_lm_rm_usage(original_method) diff --git a/knowledge_storm/lm.py b/knowledge_storm/lm.py index e9c50852..1aa34d24 100644 --- a/knowledge_storm/lm.py +++ b/knowledge_storm/lm.py @@ -9,7 +9,10 @@ import requests from dsp import ERRORS, backoff_hdlr, giveup_hdlr from dsp.modules.hf import openai_to_hf -from dsp.modules.hf_client import send_hfvllm_request_v00, send_hftgi_request_v01_wrapped +from dsp.modules.hf_client import ( + send_hfvllm_request_v00, + send_hftgi_request_v01_wrapped, +) from transformers import AutoTokenizer try: @@ -22,11 +25,11 @@ class OpenAIModel(dspy.OpenAI): """A wrapper class for dspy.OpenAI.""" def __init__( - self, - model: str = "gpt-3.5-turbo-instruct", - api_key: Optional[str] = None, - model_type: Literal["chat", "text"] = None, - **kwargs + self, + model: str = "gpt-3.5-turbo-instruct", + api_key: Optional[str] = 
None, + model_type: Literal["chat", "text"] = None, + **kwargs, ): super().__init__(model=model, api_key=api_key, model_type=model_type, **kwargs) self._token_usage_lock = threading.Lock() @@ -35,17 +38,20 @@ def __init__( def log_usage(self, response): """Log the total tokens from the OpenAI API response.""" - usage_data = response.get('usage') + usage_data = response.get("usage") if usage_data: with self._token_usage_lock: - self.prompt_tokens += usage_data.get('prompt_tokens', 0) - self.completion_tokens += usage_data.get('completion_tokens', 0) + self.prompt_tokens += usage_data.get("prompt_tokens", 0) + self.completion_tokens += usage_data.get("completion_tokens", 0) def get_usage_and_reset(self): """Get the total tokens used and reset the token usage.""" usage = { - self.kwargs.get('model') or self.kwargs.get('engine'): - {'prompt_tokens': self.prompt_tokens, 'completion_tokens': self.completion_tokens} + self.kwargs.get("model") + or self.kwargs.get("engine"): { + "prompt_tokens": self.prompt_tokens, + "completion_tokens": self.completion_tokens, + } } self.prompt_tokens = 0 self.completion_tokens = 0 @@ -53,11 +59,11 @@ def get_usage_and_reset(self): return usage def __call__( - self, - prompt: str, - only_completed: bool = True, - return_sorted: bool = False, - **kwargs, + self, + prompt: str, + only_completed: bool = True, + return_sorted: bool = False, + **kwargs, ) -> list[dict[str, Any]]: """Copied from dspy/dsp/modules/gpt3.py with the addition of tracking token usage.""" @@ -109,11 +115,11 @@ class DeepSeekModel(dspy.OpenAI): """A wrapper class for DeepSeek API, compatible with dspy.OpenAI.""" def __init__( - self, - model: str = "deepseek-chat", - api_key: Optional[str] = None, - api_base: str = "https://api.deepseek.com", - **kwargs + self, + model: str = "deepseek-chat", + api_key: Optional[str] = None, + api_base: str = "https://api.deepseek.com", + **kwargs, ): super().__init__(model=model, api_key=api_key, api_base=api_base, **kwargs) self._token_usage_lock = threading.Lock() @@ -123,21 +129,25 @@ def __init__( self.api_key = api_key or os.getenv("DEEPSEEK_API_KEY") self.api_base = api_base if not self.api_key: - raise ValueError("DeepSeek API key must be provided either as an argument or as an environment variable DEEPSEEK_API_KEY") + raise ValueError( + "DeepSeek API key must be provided either as an argument or as an environment variable DEEPSEEK_API_KEY" + ) def log_usage(self, response): """Log the total tokens from the DeepSeek API response.""" - usage_data = response.get('usage') + usage_data = response.get("usage") if usage_data: with self._token_usage_lock: - self.prompt_tokens += usage_data.get('prompt_tokens', 0) - self.completion_tokens += usage_data.get('completion_tokens', 0) + self.prompt_tokens += usage_data.get("prompt_tokens", 0) + self.completion_tokens += usage_data.get("completion_tokens", 0) def get_usage_and_reset(self): """Get the total tokens used and reset the token usage.""" usage = { - self.model: - {'prompt_tokens': self.prompt_tokens, 'completion_tokens': self.completion_tokens} + self.model: { + "prompt_tokens": self.prompt_tokens, + "completion_tokens": self.completion_tokens, + } } self.prompt_tokens = 0 self.completion_tokens = 0 @@ -154,23 +164,25 @@ def _create_completion(self, prompt: str, **kwargs): """Create a completion using the DeepSeek API.""" headers = { "Content-Type": "application/json", - "Authorization": f"Bearer {self.api_key}" + "Authorization": f"Bearer {self.api_key}", } data = { "model": self.model, "messages": 
[{"role": "user", "content": prompt}], - **kwargs + **kwargs, } - response = requests.post(f"{self.api_base}/v1/chat/completions", headers=headers, json=data) + response = requests.post( + f"{self.api_base}/v1/chat/completions", headers=headers, json=data + ) response.raise_for_status() return response.json() def __call__( - self, - prompt: str, - only_completed: bool = True, - return_sorted: bool = False, - **kwargs, + self, + prompt: str, + only_completed: bool = True, + return_sorted: bool = False, + **kwargs, ) -> list[dict[str, Any]]: """Call the DeepSeek API to generate completions.""" assert only_completed, "for now" @@ -196,35 +208,46 @@ def __call__( class AzureOpenAIModel(dspy.AzureOpenAI): """A wrapper class for dspy.AzureOpenAI.""" + def __init__( - self, - api_base: Optional[str] = None, - api_version: Optional[str] = None, - model: str = "gpt-3.5-turbo-instruct", - api_key: Optional[str] = None, - model_type: Literal["chat", "text"] = "chat", - **kwargs, + self, + api_base: Optional[str] = None, + api_version: Optional[str] = None, + model: str = "gpt-3.5-turbo-instruct", + api_key: Optional[str] = None, + model_type: Literal["chat", "text"] = "chat", + **kwargs, ): super().__init__( - api_base=api_base, api_version=api_version, model=model, api_key=api_key, model_type=model_type, **kwargs) + api_base=api_base, + api_version=api_version, + model=model, + api_key=api_key, + model_type=model_type, + **kwargs, + ) self._token_usage_lock = threading.Lock() self.prompt_tokens = 0 self.completion_tokens = 0 def log_usage(self, response): """Log the total tokens from the OpenAI API response. - Override log_usage() in dspy.AzureOpenAI for tracking accumulated token usage.""" - usage_data = response.get('usage') + Override log_usage() in dspy.AzureOpenAI for tracking accumulated token usage. 
+ """ + usage_data = response.get("usage") if usage_data: with self._token_usage_lock: - self.prompt_tokens += usage_data.get('prompt_tokens', 0) - self.completion_tokens += usage_data.get('completion_tokens', 0) + self.prompt_tokens += usage_data.get("prompt_tokens", 0) + self.completion_tokens += usage_data.get("completion_tokens", 0) def get_usage_and_reset(self): """Get the total tokens used and reset the token usage.""" usage = { - self.kwargs.get('model') or self.kwargs.get('engine'): - {'prompt_tokens': self.prompt_tokens, 'completion_tokens': self.completion_tokens} + self.kwargs.get("model") + or self.kwargs.get("engine"): { + "prompt_tokens": self.prompt_tokens, + "completion_tokens": self.completion_tokens, + } } self.prompt_tokens = 0 self.completion_tokens = 0 @@ -236,11 +259,11 @@ class ClaudeModel(dspy.dsp.modules.lm.LM): """Copied from dspy/dsp/modules/anthropic.py with the addition of tracking token usage.""" def __init__( - self, - model: str, - api_key: Optional[str] = None, - api_base: Optional[str] = None, - **kwargs, + self, + model: str, + api_key: Optional[str] = None, + api_base: Optional[str] = None, + **kwargs, ): super().__init__(model) try: @@ -249,12 +272,21 @@ def __init__( raise ImportError("Claude requires `pip install anthropic`.") from err self.provider = "anthropic" - self.api_key = api_key = os.environ.get("ANTHROPIC_API_KEY") if api_key is None else api_key - self.api_base = "https://api.anthropic.com/v1/messages" if api_base is None else api_base - self.kwargs = {"temperature": kwargs.get("temperature", 0.0), - "max_tokens": min(kwargs.get("max_tokens", 4096), 4096), "top_p": kwargs.get("top_p", 1.0), - "top_k": kwargs.get("top_k", 1), "n": kwargs.pop("n", kwargs.pop("num_generations", 1)), - **kwargs, "model": model} + self.api_key = api_key = ( + os.environ.get("ANTHROPIC_API_KEY") if api_key is None else api_key + ) + self.api_base = ( + "https://api.anthropic.com/v1/messages" if api_base is None else api_base + ) + self.kwargs = { + "temperature": kwargs.get("temperature", 0.0), + "max_tokens": min(kwargs.get("max_tokens", 4096), 4096), + "top_p": kwargs.get("top_p", 1.0), + "top_k": kwargs.get("top_k", 1), + "n": kwargs.pop("n", kwargs.pop("num_generations", 1)), + **kwargs, + "model": model, + } self.history: list[dict[str, Any]] = [] self.client = Anthropic(api_key=api_key) self.model = model @@ -274,8 +306,10 @@ def log_usage(self, response): def get_usage_and_reset(self): """Get the total tokens used and reset the token usage.""" usage = { - self.model: - {'prompt_tokens': self.prompt_tokens, 'completion_tokens': self.completion_tokens} + self.model: { + "prompt_tokens": self.prompt_tokens, + "completion_tokens": self.completion_tokens, + } } self.prompt_tokens = 0 self.completion_tokens = 0 @@ -307,7 +341,7 @@ def basic_request(self, prompt: str, **kwargs): "usage": { "input_tokens": response.usage.input_tokens, "output_tokens": response.usage.output_tokens, - } + }, }, "kwargs": kwargs, "raw_kwargs": raw_kwargs, @@ -377,10 +411,7 @@ def _generate(self, prompt, **kwargs): # "max_tokens": kwargs["max_tokens"], # "temperature": kwargs["temperature"], # } - payload = { - "prompt": prompt, - **kwargs - } + payload = {"prompt": prompt, **kwargs} response = send_hfvllm_request_v00( f"{self.url}/v1/completions", @@ -413,11 +444,17 @@ def __init__(self, model, port, url="http://localhost", **kwargs): super().__init__(model=model, base_url=f"{url}:{port}", **kwargs) # Store additional kwargs for the generate method. 
self.kwargs = {**self.kwargs, **kwargs} - + class TGIClient(dspy.HFClientTGI): def __init__(self, model, port, url, http_request_kwargs=None, **kwargs): - super().__init__(model=model, port=port, url=url, http_request_kwargs=http_request_kwargs, **kwargs) + super().__init__( + model=model, + port=port, + url=url, + http_request_kwargs=http_request_kwargs, + **kwargs, + ) def _generate(self, prompt, **kwargs): """Copied from dspy/dsp/modules/hf_client.py with the addition of removing hard-coded parameters.""" @@ -456,8 +493,8 @@ def _generate(self, prompt, **kwargs): completions = [json_response["generated_text"]] if ( - "details" in json_response - and "best_of_sequences" in json_response["details"] + "details" in json_response + and "best_of_sequences" in json_response["details"] ): completions += [ x["generated_text"] @@ -474,13 +511,22 @@ def _generate(self, prompt, **kwargs): class TogetherClient(dspy.HFModel): """A wrapper class for dspy.Together.""" - def __init__(self, model, apply_tokenizer_chat_template=False, hf_tokenizer_name=None, **kwargs): + def __init__( + self, + model, + apply_tokenizer_chat_template=False, + hf_tokenizer_name=None, + **kwargs, + ): """Copied from dspy/dsp/modules/hf_client.py with the support of applying tokenizer chat template.""" super().__init__(model=model, is_client=True) self.session = requests.Session() - self.api_base = "https://api.together.xyz/v1/completions" if os.getenv( - "TOGETHER_API_BASE") is None else os.getenv("TOGETHER_API_BASE") + self.api_base = ( + "https://api.together.xyz/v1/completions" + if os.getenv("TOGETHER_API_BASE") is None + else os.getenv("TOGETHER_API_BASE") + ) self.token = os.getenv("TOGETHER_API_KEY") self.model = model @@ -492,7 +538,9 @@ def __init__(self, model, apply_tokenizer_chat_template=False, hf_tokenizer_name logging.info("Loading huggingface tokenizer.") if hf_tokenizer_name is None: hf_tokenizer_name = self.model - self.tokenizer = AutoTokenizer.from_pretrained(hf_tokenizer_name, cache_dir=kwargs.get("cache_dir", None)) + self.tokenizer = AutoTokenizer.from_pretrained( + hf_tokenizer_name, cache_dir=kwargs.get("cache_dir", None) + ) stop_default = "\n\n---" @@ -512,17 +560,19 @@ def __init__(self, model, apply_tokenizer_chat_template=False, hf_tokenizer_name def log_usage(self, response): """Log the total tokens from the OpenAI API response.""" - usage_data = response.get('usage') + usage_data = response.get("usage") if usage_data: with self._token_usage_lock: - self.prompt_tokens += usage_data.get('prompt_tokens', 0) - self.completion_tokens += usage_data.get('completion_tokens', 0) + self.prompt_tokens += usage_data.get("prompt_tokens", 0) + self.completion_tokens += usage_data.get("completion_tokens", 0) def get_usage_and_reset(self): """Get the total tokens used and reset the token usage.""" usage = { - self.model: - {'prompt_tokens': self.prompt_tokens, 'completion_tokens': self.completion_tokens} + self.model: { + "prompt_tokens": self.prompt_tokens, + "completion_tokens": self.completion_tokens, + } } self.prompt_tokens = 0 self.completion_tokens = 0 @@ -547,14 +597,18 @@ def _generate(self, prompt, use_chat_api=False, **kwargs): top_k = kwargs.get("top_k", 50) repetition_penalty = kwargs.get("repetition_penalty", 1) if self.apply_tokenizer_chat_template: - prompt = self.tokenizer.apply_chat_template([{"role": "user", "content": prompt}], tokenize=False) + prompt = self.tokenizer.apply_chat_template( + [{"role": "user", "content": prompt}], tokenize=False + ) # prompt = f"[INST]{prompt}[/INST]" if 
self.use_inst_template else prompt

         if use_chat_api:
             url = f"{self.api_base}/chat/completions"
             messages = [
-                {"role": "system",
-                 "content": "You are a helpful assistant. You must continue the user text directly without *any* additional interjections."},
+                {
+                    "role": "system",
+                    "content": "You are a helpful assistant. You must continue the user text directly without *any* additional interjections.",
+                },
                 {"role": "user", "content": prompt},
             ]
             body = {
@@ -587,9 +641,13 @@ def _generate(self, prompt, use_chat_api=False, **kwargs):
             self.log_usage(resp_json)
             if use_chat_api:
                 # completions = [resp_json['output'].get('choices', [])[0].get('message', {}).get('content', "")]
-                completions = [resp_json.get('choices', [])[0].get('message', {}).get('content', "")]
+                completions = [
+                    resp_json.get("choices", [])[0]
+                    .get("message", {})
+                    .get("content", "")
+                ]
             else:
                 # completions = [resp_json['output'].get('choices', [])[0].get('text', "")]
-                completions = [resp_json.get('choices', [])[0].get('text', "")]
+                completions = [resp_json.get("choices", [])[0].get("text", "")]
         response = {"prompt": prompt, "choices": [{"text": c} for c in completions]}
         return response
diff --git a/knowledge_storm/rm.py b/knowledge_storm/rm.py
index 2c5fdf57..441ea570 100644
--- a/knowledge_storm/rm.py
+++ b/knowledge_storm/rm.py
@@ -12,14 +12,17 @@
 from qdrant_client import QdrantClient, models
 from tqdm import tqdm
 import requests
-from .utils import WebPageHelper
+import json
+from .utils import WebPageHelper


 class YouRM(dspy.Retrieve):
     def __init__(self, ydc_api_key=None, k=3, is_valid_source: Callable = None):
         super().__init__(k=k)
         if not ydc_api_key and not os.environ.get("YDC_API_KEY"):
-            raise RuntimeError("You must supply ydc_api_key or set environment variable YDC_API_KEY")
+            raise RuntimeError(
+                "You must supply ydc_api_key or set environment variable YDC_API_KEY"
+            )
         elif ydc_api_key:
             self.ydc_api_key = ydc_api_key
         else:
@@ -36,9 +39,11 @@ def get_usage_and_reset(self):
         usage = self.usage
         self.usage = 0

-        return {'YouRM': usage}
+        return {"YouRM": usage}

-    def forward(self, query_or_queries: Union[str, List[str]], exclude_urls: List[str] = []):
+    def forward(
+        self, query_or_queries: Union[str, List[str]], exclude_urls: List[str] = []
+    ):
         """Search with You.com for self.k top passages for query or queries

         Args:
@@ -64,21 +69,30 @@ def forward(self, query_or_queries: Union[str, List[str]], exclude_urls: List[st
             ).json()

             authoritative_results = []
-            for r in results['hits']:
-                if self.is_valid_source(r['url']) and r['url'] not in exclude_urls:
+            for r in results["hits"]:
+                if self.is_valid_source(r["url"]) and r["url"] not in exclude_urls:
                     authoritative_results.append(r)
-            if 'hits' in results:
-                collected_results.extend(authoritative_results[:self.k])
+            if "hits" in results:
+                collected_results.extend(authoritative_results[: self.k])
         except Exception as e:
-            logging.error(f'Error occurs when searching query {query}: {e}')
+            logging.error(f"Error occurs when searching query {query}: {e}")

         return collected_results


 class BingSearch(dspy.Retrieve):
-    def __init__(self, bing_search_api_key=None, k=3, is_valid_source: Callable = None,
-                 min_char_count: int = 150, snippet_chunk_size: int = 1000, webpage_helper_max_threads=10,
-                 mkt='en-US', language='en', **kwargs):
+    def __init__(
+        self,
+        bing_search_api_key=None,
+        k=3,
+        is_valid_source: Callable = None,
+        min_char_count: int = 150,
+        snippet_chunk_size: int = 1000,
+        webpage_helper_max_threads=10,
+        mkt="en-US",
+        language="en",
+        **kwargs,
+    ):
         """
         Params:
min_char_count: Minimum character count for the article to be considered valid. @@ -90,22 +104,18 @@ def __init__(self, bing_search_api_key=None, k=3, is_valid_source: Callable = No super().__init__(k=k) if not bing_search_api_key and not os.environ.get("BING_SEARCH_API_KEY"): raise RuntimeError( - "You must supply bing_search_subscription_key or set environment variable BING_SEARCH_API_KEY") + "You must supply bing_search_subscription_key or set environment variable BING_SEARCH_API_KEY" + ) elif bing_search_api_key: self.bing_api_key = bing_search_api_key else: self.bing_api_key = os.environ["BING_SEARCH_API_KEY"] self.endpoint = "https://api.bing.microsoft.com/v7.0/search" - self.params = { - 'mkt': mkt, - "setLang": language, - "count": k, - **kwargs - } + self.params = {"mkt": mkt, "setLang": language, "count": k, **kwargs} self.webpage_helper = WebPageHelper( min_char_count=min_char_count, snippet_chunk_size=snippet_chunk_size, - max_thread_num=webpage_helper_max_threads + max_thread_num=webpage_helper_max_threads, ) self.usage = 0 @@ -119,9 +129,11 @@ def get_usage_and_reset(self): usage = self.usage self.usage = 0 - return {'BingSearch': usage} + return {"BingSearch": usage} - def forward(self, query_or_queries: Union[str, List[str]], exclude_urls: List[str] = []): + def forward( + self, query_or_queries: Union[str, List[str]], exclude_urls: List[str] = [] + ): """Search with Bing for self.k top passages for query or queries Args: @@ -145,22 +157,26 @@ def forward(self, query_or_queries: Union[str, List[str]], exclude_urls: List[st for query in queries: try: results = requests.get( - self.endpoint, - headers=headers, - params={**self.params, 'q': query} + self.endpoint, headers=headers, params={**self.params, "q": query} ).json() - for d in results['webPages']['value']: - if self.is_valid_source(d['url']) and d['url'] not in exclude_urls: - url_to_results[d['url']] = {'url': d['url'], 'title': d['name'], 'description': d['snippet']} + for d in results["webPages"]["value"]: + if self.is_valid_source(d["url"]) and d["url"] not in exclude_urls: + url_to_results[d["url"]] = { + "url": d["url"], + "title": d["name"], + "description": d["snippet"], + } except Exception as e: - logging.error(f'Error occurs when searching query {query}: {e}') + logging.error(f"Error occurs when searching query {query}: {e}") - valid_url_to_snippets = self.webpage_helper.urls_to_snippets(list(url_to_results.keys())) + valid_url_to_snippets = self.webpage_helper.urls_to_snippets( + list(url_to_results.keys()) + ) collected_results = [] for url in valid_url_to_snippets: r = url_to_results[url] - r['snippets'] = valid_url_to_snippets[url]['snippets'] + r["snippets"] = valid_url_to_snippets[url]["snippets"] collected_results.append(r) return collected_results @@ -178,13 +194,15 @@ class VectorRM(dspy.Retrieve): The documents should be stored in a CSV file. """ - def __init__(self, - collection_name: str = "my_documents", - embedding_model: str = 'BAAI/bge-m3', - device: str = "mps", - k: int = 3, - chunk_size: int = 500, - chunk_overlap: int = 100): + def __init__( + self, + collection_name: str = "my_documents", + embedding_model: str = "BAAI/bge-m3", + device: str = "mps", + k: int = 3, + chunk_size: int = 500, + chunk_overlap: int = 100, + ): """ Params: collection_name: Name of the Qdrant collection. 
@@ -200,7 +218,9 @@ def __init__(self, model_kwargs = {"device": device} encode_kwargs = {"normalize_embeddings": True} self.model = HuggingFaceEmbeddings( - model_name=embedding_model, model_kwargs=model_kwargs, encode_kwargs=encode_kwargs + model_name=embedding_model, + model_kwargs=model_kwargs, + encode_kwargs=encode_kwargs, ) self.chunk_size = chunk_size @@ -217,18 +237,24 @@ def _check_create_collection(self): if self.client is None: raise ValueError("Qdrant client is not initialized.") if self.client.collection_exists(collection_name=f"{self.collection_name}"): - print(f"Collection {self.collection_name} exists. Loading the collection...") + print( + f"Collection {self.collection_name} exists. Loading the collection..." + ) self.qdrant = Qdrant( client=self.client, collection_name=self.collection_name, embeddings=self.model, ) else: - print(f"Collection {self.collection_name} does not exist. Creating the collection...") + print( + f"Collection {self.collection_name} does not exist. Creating the collection..." + ) # create the collection self.client.create_collection( collection_name=f"{self.collection_name}", - vectors_config=models.VectorParams(size=1024, distance=models.Distance.COSINE), + vectors_config=models.VectorParams( + size=1024, distance=models.Distance.COSINE + ), ) self.qdrant = Qdrant( client=self.client, @@ -274,13 +300,13 @@ def init_offline_vector_db(self, vector_store_path: str): raise ValueError(f"Error occurs when loading the vector store: {e}") def update_vector_store( - self, - file_path: str, - content_column: str, - title_column: str = "title", - url_column: str = "url", - desc_column: str = "description", - batch_size: int = 64 + self, + file_path: str, + content_column: str, + title_column: str = "title", + url_column: str = "url", + desc_column: str = "description", + batch_size: int = 64, ): """ Takes a CSV file where each row is a document and has columns for content, title, url, and description. @@ -297,7 +323,7 @@ def update_vector_store( if file_path is None: raise ValueError("Please provide a file path.") # check if the file is a csv file - if not file_path.endswith('.csv'): + if not file_path.endswith(".csv"): raise ValueError(f"Not valid file format. Please provide a csv file.") if content_column is None: raise ValueError("Please provide the name of the content column.") @@ -311,7 +337,9 @@ def update_vector_store( df = pd.read_csv(file_path) # check that content column exists and url column exists if content_column not in df.columns: - raise ValueError(f"Content column {content_column} not found in the csv file.") + raise ValueError( + f"Content column {content_column} not found in the csv file." 
+ ) if url_column not in df.columns: raise ValueError(f"URL column {url_column} not found in the csv file.") @@ -319,16 +347,17 @@ def update_vector_store( Document( page_content=row[content_column], metadata={ - "title": row.get(title_column, ''), + "title": row.get(title_column, ""), "url": row[url_column], - "description": row.get(desc_column, ''), - } + "description": row.get(desc_column, ""), + }, ) - for row in df.to_dict(orient='records') + for row in df.to_dict(orient="records") ] # split the documents from langchain_text_splitters import RecursiveCharacterTextSplitter + text_splitter = RecursiveCharacterTextSplitter( chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap, @@ -346,7 +375,7 @@ def update_vector_store( " ", "\u200B", # Zero-width space "", - ] + ], ) split_documents = text_splitter.split_documents(documents) @@ -364,7 +393,7 @@ def get_usage_and_reset(self): usage = self.usage self.usage = 0 - return {'VectorRM': usage} + return {"VectorRM": usage} def get_vector_count(self): """ @@ -397,77 +426,130 @@ def forward(self, query_or_queries: Union[str, List[str]], exclude_urls: List[st related_docs = self.qdrant.similarity_search_with_score(query, k=self.k) for i in range(len(related_docs)): doc = related_docs[i][0] - collected_results.append({ - 'description': doc.metadata['description'], - 'snippets': [doc.page_content], - 'title': doc.metadata['title'], - 'url': doc.metadata['url'], - }) + collected_results.append( + { + "description": doc.metadata["description"], + "snippets": [doc.page_content], + "title": doc.metadata["title"], + "url": doc.metadata["url"], + } + ) return collected_results -class SerperRM(dspy.Retrieve) - def __init__(self, serper_search_api_key=None, query_params=None, minibatch=None): - super().__init__() - if not self.serper_search_api_key and not os.environ.get("SERPER_API_KEY"): - raise RuntimeError( - "You must supply serper_search_api_key or set environment variable SERPER_API_KEY") - elif self.serper_search_api_key: - self.serper_search_api_key = serper_search_api_key - else: - self.serper_search_api_key = os.environ["SERPER_API_KEY"] - self.base_url = "https://google.serper.dev" - def validate_input(self): - if(self.query_params.get("type") == None or type(self.query_params.get("type")) != str): - raise RuntimeError("A type must be provided.") - elif(len(self.query_params.get("gl")) > 2 or self.query_params.get("gl") == None or type(self.query_params.get("gl")) != str): - raise RuntimeError("Country code was not provided.") - elif(self.query_params.get("hl") == None or type(self.query_params.get("hl")) != str): - raise RuntimeError("Language was not provided.") - elif(self.query_params.get("tbs") == None or type(self.query_params.get("tbs")) != str): - raise RuntimeError("Date range was not provided.") - elif(self.query_params.get("autocorrect") == None or type(self.query_params.get("autocorrect")) != bool): - raise RuntimeError("Autocorrect boolean was not provided.") - elif(self.query_params.get("results") == None or type(self.query_params.get("results")) != int): - raise RuntimeError("Number of results returned was not provided.") - elif(self.query_params.get("page") == None or type(self.query_params.get("page")) != int): - raise RuntimeError("Number of pages to return is not provided.") - elif(self.query_params.get("page") == None or type(self.query_params.get("page")) != int): - raise RuntimeError("Number of pages to return is not provided.") - elif(self.minibatch == None or type(self.query_params) == Dict): - raise 
RuntimeError("Minibatch is enabled, however query_params is a dictionary, will need to be converted to list to be able to be used.")
-    def runner(self):
-        match self.query_params.get("type"):
-            case "search":
-                self.run_process("type")
-            case "images":
-                self.run_process("images")
-            case "videos":
-                self.run_process("videos")
-            case "places":
-                self.run_process("places")
-            case "maps":
-                self.run_process("maps")
-            case "news":
-                self.run_process("news")
-            case "shopping"
-                self.run_process("shopping")
-            case "scholar"
-                self.run_process("scholar")
-            case "patents"
-                self.run_process("patents")
-            case "autocomplete":
-                self.run_process("autocomplete")
-            case ""
-    def run_process(self, process_name=None):
-        print(f"Beginning {process_name} process...")
-        self.search_url = f"{self.base_url}/{process_name}"
+
+class SerperRM(dspy.Retrieve):
+    """Retrieve information from custom queries using Serper.dev.
+
+    To be compatible with STORM, the results should have the following fields:
+    - snippets: A list of snippets that will be used for the document.
+    - title: The title of the document.
+    - url: The URL of the document. STORM uses url as the unique identifier of the document, so ensure different
+      documents have different urls.
+    - description (optional): The description of the document.
+
+    """
+
+    def __init__(self, serper_search_api_key=None, query_params=None):
+        super().__init__()
+        self.usage = 0
+        self.query_params = query_params
+        # Prefer the explicit argument; fall back to the environment variable.
+        self.serper_search_api_key = serper_search_api_key or os.environ.get(
+            "SERPER_API_KEY"
+        )
+        if not self.serper_search_api_key:
+            raise RuntimeError(
+                "You must supply a serper_search_api_key param or set environment variable SERPER_API_KEY"
+            )
+
+        self.base_url = "https://google.serper.dev"
+
+    def serper_runner(self, query_params):
+        self.search_url = f"{self.base_url}/search"
+
         headers = {
             "X-API-KEY": self.serper_search_api_key,
-            "Content-Type": "application/json"
+            "Content-Type": "application/json",
         }

-        response = requests.request("POST", self.search_url, headers=headers, data=self.query_params)
-        if response == None:
-            raise RuntimeError(f"Error had occured while running the process {process_name}.\n Error is {response.reason}, had failed with status code {response.status_code}")
+        response = requests.request(
+            "POST", self.search_url, headers=headers, json=query_params
+        )
+
+        # Surface a failed request as an error instead of parsing a bad response.
+        if not response.ok:
+            raise RuntimeError(
+                f"Error occurred while running the Serper search.\n Failed with reason {response.reason} and status code {response.status_code}."
+            )
+        return response.json()
+
+    def get_usage_and_reset(self):
+        usage = self.usage
+        self.usage = 0
+        return {"SerperRM": usage}
+
+    def forward(self, query_or_queries: Union[str, List[str]], exclude_urls: List[str]):
+        """
+        Calls the API and searches for the query passed in.
+
+        Args:
+            query_or_queries (Union[str, List[str]]): The query or queries to search for.
+            exclude_urls (List[str]): Dummy parameter to match the interface. Does not have any effect.
+
+        Returns:
+            A list of dicts; each dict has the keys 'description', 'snippets' (a list of strings), 'title', and 'url'.
+        """
+        queries = (
+            [query_or_queries]
+            if isinstance(query_or_queries, str)
+            else query_or_queries
+        )
+
+        self.usage += len(queries)
+        self.results = []
+        for query in queries:
+            # Skip the literal "Queries:" artifact that can appear in generated query lists.
+            if query == "Queries:":
+                continue
+            query_params = self.query_params
+            query_params["q"] = query
+            query_params["type"] = "search"
+            self.result = self.serper_runner(query_params)
+            self.results.append(self.result)
+
+        collected_results = []
+
+        for result in self.results:
+            try:
+                organic_results = result.get("organic")
+
+                knowledge_graph = result.get("knowledgeGraph")
+                for organic in organic_results:
+                    snippets = []
+                    snippets.append(organic.get("snippet"))
+                    if knowledge_graph is not None:
+                        collected_results.append(
+                            {
+                                "snippets": snippets,
+                                "title": organic.get("title"),
+                                "url": organic.get("link"),
+                                "description": knowledge_graph.get("description"),
+                            }
+                        )
+                    else:
+                        # Fall back to an empty description when no knowledge graph is returned.
+                        collected_results.append(
+                            {
+                                "snippets": snippets,
+                                "title": organic.get("title"),
+                                "url": organic.get("link"),
+                                "description": "",
+                            }
+                        )
+            except Exception:
+                # Skip responses that do not contain parseable organic results.
+                continue
+
+        return collected_results
diff --git a/knowledge_storm/storm_wiki/engine.py b/knowledge_storm/storm_wiki/engine.py
index e0c8dfcc..746a07b0 100644
--- a/knowledge_storm/storm_wiki/engine.py
+++ b/knowledge_storm/storm_wiki/engine.py
@@ -28,43 +28,52 @@ class STORMWikiLMConfigs:
     """

     def __init__(self):
-        self.conv_simulator_lm = None  # LLM used in conversation simulator except for question asking.
+        self.conv_simulator_lm = (
+            None  # LLM used in conversation simulator except for question asking.
+        )
         self.question_asker_lm = None  # LLM used in question asking.
         self.outline_gen_lm = None  # LLM used in outline generation.
         self.article_gen_lm = None  # LLM used in article generation.
         self.article_polish_lm = None  # LLM used in article polishing.

     def init_openai_model(
-            self,
-            openai_api_key: str,
-            openai_type: Literal["openai", "azure"],
-            api_base: Optional[str] = None,
-            api_version: Optional[str] = None,
-            temperature: Optional[float] = 1.0,
-            top_p: Optional[float] = 0.9
+        self,
+        openai_api_key: str,
+        openai_type: Literal["openai", "azure"],
+        api_base: Optional[str] = None,
+        api_version: Optional[str] = None,
+        temperature: Optional[float] = 1.0,
+        top_p: Optional[float] = 0.9,
     ):
         """Legacy: Corresponding to the original setup in the NAACL'24 paper."""
         openai_kwargs = {
-            'api_key': openai_api_key,
-            'api_provider': openai_type,
-            'temperature': temperature,
-            'top_p': top_p,
-            'api_base': None
+            "api_key": openai_api_key,
+            "api_provider": openai_type,
+            "temperature": temperature,
+            "top_p": top_p,
+            "api_base": None,
         }
-        if openai_type and openai_type == 'openai':
-            self.conv_simulator_lm = OpenAIModel(model='gpt-3.5-turbo-instruct',
-                                                 max_tokens=500, **openai_kwargs)
-            self.question_asker_lm = OpenAIModel(model='gpt-3.5-turbo',
-                                                 max_tokens=500, **openai_kwargs)
+        if openai_type and openai_type == "openai":
+            self.conv_simulator_lm = OpenAIModel(
+                model="gpt-3.5-turbo-instruct", max_tokens=500, **openai_kwargs
+            )
+            self.question_asker_lm = OpenAIModel(
+                model="gpt-3.5-turbo", max_tokens=500, **openai_kwargs
+            )
             # 1/12/2024: Update gpt-4 to gpt-4-1106-preview. (Currently keep the original setup when using azure.)
- self.outline_gen_lm = OpenAIModel(model='gpt-4-0125-preview', - max_tokens=400, **openai_kwargs) - self.article_gen_lm = OpenAIModel(model='gpt-4o-2024-05-13', - max_tokens=700, **openai_kwargs) - self.article_polish_lm = OpenAIModel(model='gpt-4o-2024-05-13', - max_tokens=4000, **openai_kwargs) + self.outline_gen_lm = OpenAIModel( + model="gpt-4-0125-preview", max_tokens=400, **openai_kwargs + ) + self.article_gen_lm = OpenAIModel( + model="gpt-4o-2024-05-13", max_tokens=700, **openai_kwargs + ) + self.article_polish_lm = OpenAIModel( + model="gpt-4o-2024-05-13", max_tokens=4000, **openai_kwargs + ) else: - logging.warning('No valid OpenAI API provider is provided. Cannot use default LLM configurations.') + logging.warning( + "No valid OpenAI API provider is provided. Cannot use default LLM configurations." + ) def set_conv_simulator_lm(self, model: Union[dspy.dsp.LM, dspy.dsp.HFModel]): self.conv_simulator_lm = model @@ -85,16 +94,21 @@ def set_article_polish_lm(self, model: Union[dspy.dsp.LM, dspy.dsp.HFModel]): @dataclass class STORMWikiRunnerArguments: """Arguments for controlling the STORM Wiki pipeline.""" + output_dir: str = field( metadata={"help": "Output directory for the results."}, ) max_conv_turn: int = field( default=3, - metadata={"help": "Maximum number of questions in conversational question asking."}, + metadata={ + "help": "Maximum number of questions in conversational question asking." + }, ) max_perspective: int = field( default=3, - metadata={"help": "Maximum number of perspectives to consider in perspective-guided question asking."}, + metadata={ + "help": "Maximum number of perspectives to consider in perspective-guided question asking." + }, ) max_search_queries_per_turn: int = field( default=3, @@ -114,24 +128,27 @@ class STORMWikiRunnerArguments: ) max_thread_num: int = field( default=10, - metadata={"help": "Maximum number of threads to use. " - "Consider reducing it if keep getting 'Exceed rate limit' error when calling LM API."}, + metadata={ + "help": "Maximum number of threads to use. " + "Consider reducing it if keep getting 'Exceed rate limit' error when calling LM API." 
+ }, ) class STORMWikiRunner(Engine): """STORM Wiki pipeline runner.""" - def __init__(self, - args: STORMWikiRunnerArguments, - lm_configs: STORMWikiLMConfigs, - rm): + def __init__( + self, args: STORMWikiRunnerArguments, lm_configs: STORMWikiLMConfigs, rm + ): super().__init__(lm_configs=lm_configs) self.args = args self.lm_configs = lm_configs self.retriever = StormRetriever(rm=rm, k=self.args.retrieve_top_k) - storm_persona_generator = StormPersonaGenerator(self.lm_configs.question_asker_lm) + storm_persona_generator = StormPersonaGenerator( + self.lm_configs.question_asker_lm + ) self.storm_knowledge_curation_module = StormKnowledgeCurationModule( retriever=self.retriever, persona_generator=storm_persona_generator, @@ -140,7 +157,7 @@ def __init__(self, max_search_queries_per_turn=self.args.max_search_queries_per_turn, search_top_k=self.args.search_top_k, max_conv_turn=self.args.max_conv_turn, - max_thread_num=self.args.max_thread_num + max_thread_num=self.args.max_thread_num, ) self.storm_outline_generation_module = StormOutlineGenerationModule( outline_gen_lm=self.lm_configs.outline_gen_lm @@ -148,73 +165,96 @@ def __init__(self, self.storm_article_generation = StormArticleGenerationModule( article_gen_lm=self.lm_configs.article_gen_lm, retrieve_top_k=self.args.retrieve_top_k, - max_thread_num=self.args.max_thread_num + max_thread_num=self.args.max_thread_num, ) self.storm_article_polishing_module = StormArticlePolishingModule( article_gen_lm=self.lm_configs.article_gen_lm, - article_polish_lm=self.lm_configs.article_polish_lm + article_polish_lm=self.lm_configs.article_polish_lm, ) self.lm_configs.init_check() self.apply_decorators() - def run_knowledge_curation_module(self, - ground_truth_url: str = "None", - callback_handler: BaseCallbackHandler = None) -> StormInformationTable: - - information_table, conversation_log = self.storm_knowledge_curation_module.research( - topic=self.topic, - ground_truth_url=ground_truth_url, - callback_handler=callback_handler, - max_perspective=self.args.max_perspective, - disable_perspective=False, - return_conversation_log=True + def run_knowledge_curation_module( + self, + ground_truth_url: str = "None", + callback_handler: BaseCallbackHandler = None, + ) -> StormInformationTable: + + information_table, conversation_log = ( + self.storm_knowledge_curation_module.research( + topic=self.topic, + ground_truth_url=ground_truth_url, + callback_handler=callback_handler, + max_perspective=self.args.max_perspective, + disable_perspective=False, + return_conversation_log=True, + ) ) - FileIOHelper.dump_json(conversation_log, os.path.join(self.article_output_dir, 'conversation_log.json')) - information_table.dump_url_to_info(os.path.join(self.article_output_dir, 'raw_search_results.json')) + FileIOHelper.dump_json( + conversation_log, + os.path.join(self.article_output_dir, "conversation_log.json"), + ) + information_table.dump_url_to_info( + os.path.join(self.article_output_dir, "raw_search_results.json") + ) return information_table - def run_outline_generation_module(self, - information_table: StormInformationTable, - callback_handler: BaseCallbackHandler = None) -> StormArticle: + def run_outline_generation_module( + self, + information_table: StormInformationTable, + callback_handler: BaseCallbackHandler = None, + ) -> StormArticle: outline, draft_outline = self.storm_outline_generation_module.generate_outline( topic=self.topic, information_table=information_table, return_draft_outline=True, - callback_handler=callback_handler + 
callback_handler=callback_handler, + ) + outline.dump_outline_to_file( + os.path.join(self.article_output_dir, "storm_gen_outline.txt") + ) + draft_outline.dump_outline_to_file( + os.path.join(self.article_output_dir, "direct_gen_outline.txt") ) - outline.dump_outline_to_file(os.path.join(self.article_output_dir, 'storm_gen_outline.txt')) - draft_outline.dump_outline_to_file(os.path.join(self.article_output_dir, "direct_gen_outline.txt")) return outline - def run_article_generation_module(self, - outline: StormArticle, - information_table=StormInformationTable, - callback_handler: BaseCallbackHandler = None) -> StormArticle: + def run_article_generation_module( + self, + outline: StormArticle, + information_table=StormInformationTable, + callback_handler: BaseCallbackHandler = None, + ) -> StormArticle: draft_article = self.storm_article_generation.generate_article( topic=self.topic, information_table=information_table, article_with_outline=outline, - callback_handler=callback_handler + callback_handler=callback_handler, + ) + draft_article.dump_article_as_plain_text( + os.path.join(self.article_output_dir, "storm_gen_article.txt") + ) + draft_article.dump_reference_to_file( + os.path.join(self.article_output_dir, "url_to_info.json") ) - draft_article.dump_article_as_plain_text(os.path.join(self.article_output_dir, 'storm_gen_article.txt')) - draft_article.dump_reference_to_file(os.path.join(self.article_output_dir, 'url_to_info.json')) return draft_article - def run_article_polishing_module(self, - draft_article: StormArticle, - remove_duplicate: bool = False) -> StormArticle: + def run_article_polishing_module( + self, draft_article: StormArticle, remove_duplicate: bool = False + ) -> StormArticle: polished_article = self.storm_article_polishing_module.polish_article( topic=self.topic, draft_article=draft_article, - remove_duplicate=remove_duplicate + remove_duplicate=remove_duplicate, + ) + FileIOHelper.write_str( + polished_article.to_string(), + os.path.join(self.article_output_dir, "storm_gen_article_polished.txt"), ) - FileIOHelper.write_str(polished_article.to_string(), - os.path.join(self.article_output_dir, 'storm_gen_article_polished.txt')) return polished_article def post_run(self): @@ -224,43 +264,61 @@ def post_run(self): 2. Dumping the LLM call history. """ config_log = self.lm_configs.log() - FileIOHelper.dump_json(config_log, os.path.join(self.article_output_dir, 'run_config.json')) + FileIOHelper.dump_json( + config_log, os.path.join(self.article_output_dir, "run_config.json") + ) llm_call_history = self.lm_configs.collect_and_reset_lm_history() - with open(os.path.join(self.article_output_dir, 'llm_call_history.jsonl'), 'w') as f: + with open( + os.path.join(self.article_output_dir, "llm_call_history.jsonl"), "w" + ) as f: for call in llm_call_history: - if 'kwargs' in call: - call.pop('kwargs') # All kwargs are dumped together to run_config.json. - f.write(json.dumps(call) + '\n') + if "kwargs" in call: + call.pop( + "kwargs" + ) # All kwargs are dumped together to run_config.json. + f.write(json.dumps(call) + "\n") def _load_information_table_from_local_fs(self, information_table_local_path): assert os.path.exists(information_table_local_path), makeStringRed( - f"{information_table_local_path} not exists. Please set --do-research argument to prepare the conversation_log.json for this topic.") - return StormInformationTable.from_conversation_log_file(information_table_local_path) + f"{information_table_local_path} not exists. 
Please set --do-research argument to prepare the conversation_log.json for this topic."
+        )
+        return StormInformationTable.from_conversation_log_file(
+            information_table_local_path
+        )
 
     def _load_outline_from_local_fs(self, topic, outline_local_path):
         assert os.path.exists(outline_local_path), makeStringRed(
-            f"{outline_local_path} not exists. Please set --do-generate-outline argument to prepare the storm_gen_outline.txt for this topic.")
+            f"{outline_local_path} does not exist. Please set --do-generate-outline argument to prepare the storm_gen_outline.txt for this topic."
+        )
         return StormArticle.from_outline_file(topic=topic, file_path=outline_local_path)
 
-    def _load_draft_article_from_local_fs(self, topic, draft_article_path, url_to_info_path):
+    def _load_draft_article_from_local_fs(
+        self, topic, draft_article_path, url_to_info_path
+    ):
         assert os.path.exists(draft_article_path), makeStringRed(
-            f"{draft_article_path} not exists. Please set --do-generate-article argument to prepare the storm_gen_article.txt for this topic.")
+            f"{draft_article_path} does not exist. Please set --do-generate-article argument to prepare the storm_gen_article.txt for this topic."
+        )
         assert os.path.exists(url_to_info_path), makeStringRed(
-            f"{url_to_info_path} not exists. Please set --do-generate-article argument to prepare the url_to_info.json for this topic.")
+            f"{url_to_info_path} does not exist. Please set --do-generate-article argument to prepare the url_to_info.json for this topic."
+        )
         article_text = FileIOHelper.load_str(draft_article_path)
         references = FileIOHelper.load_json(url_to_info_path)
-        return StormArticle.from_string(topic_name=topic, article_text=article_text, references=references)
-
-    def run(self,
-            topic: str,
-            ground_truth_url: str = '',
-            do_research: bool = True,
-            do_generate_outline: bool = True,
-            do_generate_article: bool = True,
-            do_polish_article: bool = True,
-            remove_duplicate: bool = False,
-            callback_handler: BaseCallbackHandler = BaseCallbackHandler()):
+        return StormArticle.from_string(
+            topic_name=topic, article_text=article_text, references=references
+        )
+
+    def run(
+        self,
+        topic: str,
+        ground_truth_url: str = "",
+        do_research: bool = True,
+        do_generate_outline: bool = True,
+        do_generate_article: bool = True,
+        do_polish_article: bool = True,
+        remove_duplicate: bool = False,
+        callback_handler: BaseCallbackHandler = BaseCallbackHandler(),
+    ):
         """
         Run the STORM pipeline.
 
@@ -278,50 +336,74 @@ def run(self,
             remove_duplicate: If True, remove duplicated content.
             callback_handler: A callback handler to handle the intermediate results.
         """
-        assert do_research or do_generate_outline or do_generate_article or do_polish_article, \
-            makeStringRed(
-                "No action is specified. Please set at least one of --do-research, --do-generate-outline, --do-generate-article, --do-polish-article")
+        assert (
+            do_research
+            or do_generate_outline
+            or do_generate_article
+            or do_polish_article
+        ), makeStringRed(
+            "No action is specified. 
Please set at least one of --do-research, --do-generate-outline, --do-generate-article, --do-polish-article" + ) self.topic = topic - self.article_dir_name = topic.replace(' ', '_').replace('/', '_') - self.article_output_dir = os.path.join(self.args.output_dir, self.article_dir_name) + self.article_dir_name = topic.replace(" ", "_").replace("/", "_") + self.article_output_dir = os.path.join( + self.args.output_dir, self.article_dir_name + ) os.makedirs(self.article_output_dir, exist_ok=True) # research module information_table: StormInformationTable = None if do_research: - information_table = self.run_knowledge_curation_module(ground_truth_url=ground_truth_url, - callback_handler=callback_handler) + information_table = self.run_knowledge_curation_module( + ground_truth_url=ground_truth_url, callback_handler=callback_handler + ) # outline generation module outline: StormArticle = None if do_generate_outline: # load information table if it's not initialized if information_table is None: information_table = self._load_information_table_from_local_fs( - os.path.join(self.article_output_dir, 'conversation_log.json')) - outline = self.run_outline_generation_module(information_table=information_table, - callback_handler=callback_handler) + os.path.join(self.article_output_dir, "conversation_log.json") + ) + outline = self.run_outline_generation_module( + information_table=information_table, callback_handler=callback_handler + ) # article generation module draft_article: StormArticle = None if do_generate_article: if information_table is None: information_table = self._load_information_table_from_local_fs( - os.path.join(self.article_output_dir, 'conversation_log.json')) + os.path.join(self.article_output_dir, "conversation_log.json") + ) if outline is None: - outline = self._load_outline_from_local_fs(topic=topic, - outline_local_path=os.path.join(self.article_output_dir, - 'storm_gen_outline.txt')) - draft_article = self.run_article_generation_module(outline=outline, - information_table=information_table, - callback_handler=callback_handler) + outline = self._load_outline_from_local_fs( + topic=topic, + outline_local_path=os.path.join( + self.article_output_dir, "storm_gen_outline.txt" + ), + ) + draft_article = self.run_article_generation_module( + outline=outline, + information_table=information_table, + callback_handler=callback_handler, + ) # article polishing module if do_polish_article: if draft_article is None: - draft_article_path = os.path.join(self.article_output_dir, 'storm_gen_article.txt') - url_to_info_path = os.path.join(self.article_output_dir, 'url_to_info.json') - draft_article = self._load_draft_article_from_local_fs(topic=topic, - draft_article_path=draft_article_path, - url_to_info_path=url_to_info_path) - self.run_article_polishing_module(draft_article=draft_article, remove_duplicate=remove_duplicate) + draft_article_path = os.path.join( + self.article_output_dir, "storm_gen_article.txt" + ) + url_to_info_path = os.path.join( + self.article_output_dir, "url_to_info.json" + ) + draft_article = self._load_draft_article_from_local_fs( + topic=topic, + draft_article_path=draft_article_path, + url_to_info_path=url_to_info_path, + ) + self.run_article_polishing_module( + draft_article=draft_article, remove_duplicate=remove_duplicate + ) diff --git a/knowledge_storm/storm_wiki/modules/article_generation.py b/knowledge_storm/storm_wiki/modules/article_generation.py index a114b3ec..2e711465 100644 --- a/knowledge_storm/storm_wiki/modules/article_generation.py +++ 
b/knowledge_storm/storm_wiki/modules/article_generation.py
@@ -15,35 +15,48 @@ class StormArticleGenerationModule(ArticleGenerationModule):
     """
     The interface for article generation stage. Given topic, collected information from
-    knowledge curation stage, generated outline from outline generation stage,
+    knowledge curation stage, and generated outline from outline generation stage, generate the article for the topic.
     """
 
-    def __init__(self,
-                 article_gen_lm=Union[dspy.dsp.LM, dspy.dsp.HFModel],
-                 retrieve_top_k: int = 5,
-                 max_thread_num: int = 10):
+    def __init__(
+        self,
+        article_gen_lm: Union[dspy.dsp.LM, dspy.dsp.HFModel],
+        retrieve_top_k: int = 5,
+        max_thread_num: int = 10,
+    ):
         super().__init__()
         self.retrieve_top_k = retrieve_top_k
         self.article_gen_lm = article_gen_lm
         self.max_thread_num = max_thread_num
         self.section_gen = ConvToSection(engine=self.article_gen_lm)
 
-    def generate_section(self, topic, section_name, information_table, section_outline, section_query):
+    def generate_section(
+        self, topic, section_name, information_table, section_outline, section_query
+    ):
         collected_info: List[StormInformation] = []
         if information_table is not None:
-            collected_info = information_table.retrieve_information(queries=section_query,
-                                                                    search_top_k=self.retrieve_top_k)
-        output = self.section_gen(topic=topic,
-                                  outline=section_outline,
-                                  section=section_name,
-                                  collected_info=collected_info)
-        return {"section_name": section_name, "section_content": output.section, "collected_info": collected_info}
-
-    def generate_article(self,
-                         topic: str,
-                         information_table: StormInformationTable,
-                         article_with_outline: StormArticle,
-                         callback_handler: BaseCallbackHandler = None) -> StormArticle:
+            collected_info = information_table.retrieve_information(
+                queries=section_query, search_top_k=self.retrieve_top_k
+            )
+        output = self.section_gen(
+            topic=topic,
+            outline=section_outline,
+            section=section_name,
+            collected_info=collected_info,
+        )
+        return {
+            "section_name": section_name,
+            "section_content": output.section,
+            "collected_info": collected_info,
+        }
+
+    def generate_article(
+        self,
+        topic: str,
+        information_table: StormInformationTable,
+        article_with_outline: StormArticle,
+        callback_handler: BaseCallbackHandler = None,
+    ) -> StormArticle:
         """
         Generate article for the topic based on the information table and article outline.
@@ -63,35 +76,48 @@ def generate_article(self,
         section_output_dict_collection = []
 
         if len(sections_to_write) == 0:
-            logging.error(f'No outline for {topic}. Will directly search with the topic.')
+            logging.error(
+                f"No outline for {topic}. Will directly search with the topic."
+            )
             section_output_dict = self.generate_section(
                 topic=topic,
                 section_name=topic,
                 information_table=information_table,
                 section_outline="",
-                section_query=[topic]
+                section_query=[topic],
             )
             section_output_dict_collection = [section_output_dict]
         else:
-            with concurrent.futures.ThreadPoolExecutor(max_workers=self.max_thread_num) as executor:
+            with concurrent.futures.ThreadPoolExecutor(
+                max_workers=self.max_thread_num
+            ) as executor:
                 future_to_sec_title = {}
                 for section_title in sections_to_write:
                     # We don't want to write a separate introduction section.
-                    if section_title.lower().strip() == 'introduction':
+                    if section_title.lower().strip() == "introduction":
                         continue
                     # We don't want to write a separate conclusion section.
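                    # (A summary is produced later anyway: the polishing stage
                    # prepends a "# summary" lead section; see article_polish.py below.)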
if section_title.lower().strip().startswith( - 'conclusion') or section_title.lower().strip().startswith('summary'): + "conclusion" + ) or section_title.lower().strip().startswith("summary"): continue - section_query = article_with_outline.get_outline_as_list(root_section_name=section_title, - add_hashtags=False) + section_query = article_with_outline.get_outline_as_list( + root_section_name=section_title, add_hashtags=False + ) queries_with_hashtags = article_with_outline.get_outline_as_list( - root_section_name=section_title, add_hashtags=True) + root_section_name=section_title, add_hashtags=True + ) section_outline = "\n".join(queries_with_hashtags) future_to_sec_title[ - executor.submit(self.generate_section, - topic, section_title, information_table, section_outline, section_query) + executor.submit( + self.generate_section, + topic, + section_title, + information_table, + section_outline, + section_query, + ) ] = section_title for future in as_completed(future_to_sec_title): @@ -99,9 +125,11 @@ def generate_article(self, article = copy.deepcopy(article_with_outline) for section_output_dict in section_output_dict_collection: - article.update_section(parent_section_name=topic, - current_section_content=section_output_dict["section_content"], - current_section_info_list=section_output_dict["collected_info"]) + article.update_section( + parent_section_name=topic, + current_section_content=section_output_dict["section_content"], + current_section_info_list=section_output_dict["collected_info"], + ) article.post_processing() return article @@ -114,17 +142,24 @@ def __init__(self, engine: Union[dspy.dsp.LM, dspy.dsp.HFModel]): self.write_section = dspy.Predict(WriteSection) self.engine = engine - def forward(self, topic: str, outline: str, section: str, collected_info: List[StormInformation]): - info = '' + def forward( + self, + topic: str, + outline: str, + section: str, + collected_info: List[StormInformation], + ): + info = "" for idx, storm_info in enumerate(collected_info): - info += f'[{idx + 1}]\n' + '\n'.join(storm_info.snippets) - info += '\n\n' + info += f"[{idx + 1}]\n" + "\n".join(storm_info.snippets) + info += "\n\n" info = ArticleTextProcessing.limit_word_count_preserve_newline(info, 1500) with dspy.settings.context(lm=self.engine): section = ArticleTextProcessing.clean_up_section( - self.write_section(topic=topic, info=info, section=section).output) + self.write_section(topic=topic, info=info, section=section).output + ) return dspy.Prediction(section=section) @@ -132,9 +167,9 @@ def forward(self, topic: str, outline: str, section: str, collected_info: List[S class WriteSection(dspy.Signature): """Write a Wikipedia section based on the collected information. - Here is the format of your writing: - 1. Use "#" Title" to indicate section title, "##" Title" to indicate subsection title, "###" Title" to indicate subsubsection title, and so on. - 2. Use [1], [2], ..., [n] in line (for example, "The capital of the United States is Washington, D.C.[1][3]."). You DO NOT need to include a References or Sources section to list the sources at the end. + Here is the format of your writing: + 1. Use "#" Title" to indicate section title, "##" Title" to indicate subsection title, "###" Title" to indicate subsubsection title, and so on. + 2. Use [1], [2], ..., [n] in line (for example, "The capital of the United States is Washington, D.C.[1][3]."). You DO NOT need to include a References or Sources section to list the sources at the end. 
""" info = dspy.InputField(prefix="The collected information:\n", format=str) @@ -142,5 +177,5 @@ class WriteSection(dspy.Signature): section = dspy.InputField(prefix="The section you need to write: ", format=str) output = dspy.OutputField( prefix="Write the section with proper inline citations (Start your writing with # section title. Don't include the page title or try to write other sections):\n", - format=str + format=str, ) diff --git a/knowledge_storm/storm_wiki/modules/article_polish.py b/knowledge_storm/storm_wiki/modules/article_polish.py index b70bb834..fb85b0f3 100644 --- a/knowledge_storm/storm_wiki/modules/article_polish.py +++ b/knowledge_storm/storm_wiki/modules/article_polish.py @@ -14,21 +14,21 @@ class StormArticlePolishingModule(ArticlePolishingModule): knowledge curation stage, generated outline from outline generation stage. """ - def __init__(self, - article_gen_lm: Union[dspy.dsp.LM, dspy.dsp.HFModel], - article_polish_lm: Union[dspy.dsp.LM, dspy.dsp.HFModel]): + def __init__( + self, + article_gen_lm: Union[dspy.dsp.LM, dspy.dsp.HFModel], + article_polish_lm: Union[dspy.dsp.LM, dspy.dsp.HFModel], + ): self.article_gen_lm = article_gen_lm self.article_polish_lm = article_polish_lm self.polish_page = PolishPageModule( - write_lead_engine=self.article_gen_lm, - polish_engine=self.article_polish_lm + write_lead_engine=self.article_gen_lm, polish_engine=self.article_polish_lm ) - def polish_article(self, - topic: str, - draft_article: StormArticle, - remove_duplicate: bool = False) -> StormArticle: + def polish_article( + self, topic: str, draft_article: StormArticle, remove_duplicate: bool = False + ) -> StormArticle: """ Polish article. @@ -39,10 +39,14 @@ def polish_article(self, """ article_text = draft_article.to_string() - polish_result = self.polish_page(topic=topic, draft_page=article_text, polish_whole_page=remove_duplicate) + polish_result = self.polish_page( + topic=topic, draft_page=article_text, polish_whole_page=remove_duplicate + ) lead_section = f"# summary\n{polish_result.lead_section}" - polished_article = '\n\n'.join([lead_section, polish_result.page]) - polished_article_dict = ArticleTextProcessing.parse_article_into_dict(polished_article) + polished_article = "\n\n".join([lead_section, polish_result.page]) + polished_article_dict = ArticleTextProcessing.parse_article_into_dict( + polished_article + ) polished_article = copy.deepcopy(draft_article) polished_article.insert_or_create_section(article_dict=polished_article_dict) polished_article.post_processing() @@ -51,9 +55,10 @@ def polish_article(self, class WriteLeadSection(dspy.Signature): """Write a lead section for the given Wikipedia page with the following guidelines: - 1. The lead should stand on its own as a concise overview of the article's topic. It should identify the topic, establish context, explain why the topic is notable, and summarize the most important points, including any prominent controversies. - 2. The lead section should be concise and contain no more than four well-composed paragraphs. - 3. The lead section should be carefully sourced as appropriate. Add inline citations (e.g., "Washington, D.C., is the capital of the United States.[1][3].") where necessary.""" + 1. The lead should stand on its own as a concise overview of the article's topic. It should identify the topic, establish context, explain why the topic is notable, and summarize the most important points, including any prominent controversies. + 2. 
The lead section should be concise and contain no more than four well-composed paragraphs. + 3. The lead section should be carefully sourced as appropriate. Add inline citations (e.g., "Washington, D.C., is the capital of the United States.[1][3].") where necessary. + """ topic = dspy.InputField(prefix="The topic of the page: ", format=str) draft_page = dspy.InputField(prefix="The draft page:\n", format=str) @@ -68,8 +73,11 @@ class PolishPage(dspy.Signature): class PolishPageModule(dspy.Module): - def __init__(self, write_lead_engine: Union[dspy.dsp.LM, dspy.dsp.HFModel], - polish_engine: Union[dspy.dsp.LM, dspy.dsp.HFModel]): + def __init__( + self, + write_lead_engine: Union[dspy.dsp.LM, dspy.dsp.HFModel], + polish_engine: Union[dspy.dsp.LM, dspy.dsp.HFModel], + ): super().__init__() self.write_lead_engine = write_lead_engine self.polish_engine = polish_engine @@ -78,7 +86,9 @@ def __init__(self, write_lead_engine: Union[dspy.dsp.LM, dspy.dsp.HFModel], def forward(self, topic: str, draft_page: str, polish_whole_page: bool = True): with dspy.settings.context(lm=self.write_lead_engine): - lead_section = self.write_lead(topic=topic, draft_page=draft_page).lead_section + lead_section = self.write_lead( + topic=topic, draft_page=draft_page + ).lead_section if "The lead section:" in lead_section: lead_section = lead_section.split("The lead section:")[1].strip() if polish_whole_page: diff --git a/knowledge_storm/storm_wiki/modules/knowledge_curation.py b/knowledge_storm/storm_wiki/modules/knowledge_curation.py index 8e881c65..bde27678 100644 --- a/knowledge_storm/storm_wiki/modules/knowledge_curation.py +++ b/knowledge_storm/storm_wiki/modules/knowledge_curation.py @@ -25,20 +25,32 @@ class ConvSimulator(dspy.Module): """Simulate a conversation between a Wikipedia writer with specific persona and an expert.""" - def __init__(self, topic_expert_engine: Union[dspy.dsp.LM, dspy.dsp.HFModel], - question_asker_engine: Union[dspy.dsp.LM, dspy.dsp.HFModel], - retriever: Retriever, max_search_queries_per_turn: int, search_top_k: int, max_turn: int): + def __init__( + self, + topic_expert_engine: Union[dspy.dsp.LM, dspy.dsp.HFModel], + question_asker_engine: Union[dspy.dsp.LM, dspy.dsp.HFModel], + retriever: Retriever, + max_search_queries_per_turn: int, + search_top_k: int, + max_turn: int, + ): super().__init__() self.wiki_writer = WikiWriter(engine=question_asker_engine) self.topic_expert = TopicExpert( engine=topic_expert_engine, max_search_queries=max_search_queries_per_turn, search_top_k=search_top_k, - retriever=retriever + retriever=retriever, ) self.max_turn = max_turn - def forward(self, topic: str, persona: str, ground_truth_url: str, callback_handler: BaseCallbackHandler): + def forward( + self, + topic: str, + persona: str, + ground_truth_url: str, + callback_handler: BaseCallbackHandler, + ): """ topic: The topic to research. persona: The persona of the Wikipedia writer. 
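
        Example (illustrative wiring; mirrors how the curation module below constructs it):
            simulator = ConvSimulator(
                topic_expert_engine=conv_simulator_lm,
                question_asker_engine=question_asker_lm,
                retriever=retriever,
                max_search_queries_per_turn=3,
                search_top_k=3,
                max_turn=3,
            )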
@@ -46,18 +58,22 @@ def forward(self, topic: str, persona: str, ground_truth_url: str, callback_hand """ dlg_history: List[DialogueTurn] = [] for _ in range(self.max_turn): - user_utterance = self.wiki_writer(topic=topic, persona=persona, dialogue_turns=dlg_history).question - if user_utterance == '': - logging.error('Simulated Wikipedia writer utterance is empty.') + user_utterance = self.wiki_writer( + topic=topic, persona=persona, dialogue_turns=dlg_history + ).question + if user_utterance == "": + logging.error("Simulated Wikipedia writer utterance is empty.") break - if user_utterance.startswith('Thank you so much for your help!'): + if user_utterance.startswith("Thank you so much for your help!"): break - expert_output = self.topic_expert(topic=topic, question=user_utterance, ground_truth_url=ground_truth_url) + expert_output = self.topic_expert( + topic=topic, question=user_utterance, ground_truth_url=ground_truth_url + ) dlg_turn = DialogueTurn( agent_utterance=expert_output.answer, user_utterance=user_utterance, search_queries=expert_output.queries, - search_results=expert_output.searched_results + search_results=expert_output.searched_results, ) dlg_history.append(dlg_turn) callback_handler.on_dialogue_turn_end(dlg_turn=dlg_turn) @@ -76,22 +92,35 @@ def __init__(self, engine: Union[dspy.dsp.LM, dspy.dsp.HFModel]): self.ask_question = dspy.ChainOfThought(AskQuestion) self.engine = engine - def forward(self, topic: str, persona: str, dialogue_turns: List[DialogueTurn], draft_page=None): + def forward( + self, + topic: str, + persona: str, + dialogue_turns: List[DialogueTurn], + draft_page=None, + ): conv = [] for turn in dialogue_turns[:-4]: - conv.append(f'You: {turn.user_utterance}\nExpert: Omit the answer here due to space limit.') + conv.append( + f"You: {turn.user_utterance}\nExpert: Omit the answer here due to space limit." + ) for turn in dialogue_turns[-4:]: conv.append( - f'You: {turn.user_utterance}\nExpert: {ArticleTextProcessing.remove_citations(turn.agent_utterance)}') - conv = '\n'.join(conv) - conv = conv.strip() or 'N/A' + f"You: {turn.user_utterance}\nExpert: {ArticleTextProcessing.remove_citations(turn.agent_utterance)}" + ) + conv = "\n".join(conv) + conv = conv.strip() or "N/A" conv = ArticleTextProcessing.limit_word_count_preserve_newline(conv, 2500) with dspy.settings.context(lm=self.engine): if persona is not None and len(persona.strip()) > 0: - question = self.ask_question_with_persona(topic=topic, persona=persona, conv=conv).question + question = self.ask_question_with_persona( + topic=topic, persona=persona, conv=conv + ).question else: - question = self.ask_question(topic=topic, persona=persona, conv=conv).question + question = self.ask_question( + topic=topic, persona=persona, conv=conv + ).question return dspy.Prediction(question=question) @@ -99,10 +128,11 @@ def forward(self, topic: str, persona: str, dialogue_turns: List[DialogueTurn], class AskQuestion(dspy.Signature): """You are an experienced Wikipedia writer. You are chatting with an expert to get information for the topic you want to contribute. Ask good questions to get more useful information relevant to the topic. When you have no more question to ask, say "Thank you so much for your help!" to end the conversation. - Please only ask a question at a time and don't ask what you have asked before. Your questions should be related to the topic you want to write.""" + Please only ask a question at a time and don't ask what you have asked before. 
Your questions should be related to the topic you want to write.
+    """
 
-    topic = dspy.InputField(prefix='Topic you want to write: ', format=str)
-    conv = dspy.InputField(prefix='Conversation history:\n', format=str)
+    topic = dspy.InputField(prefix="Topic you want to write: ", format=str)
+    conv = dspy.InputField(prefix="Conversation history:\n", format=str)
     question = dspy.OutputField(format=str)
 
 
@@ -110,38 +140,41 @@ class AskQuestionWithPersona(dspy.Signature):
     """You are an experienced Wikipedia writer and want to edit a specific page. Besides your identity as a Wikipedia writer, you have specific focus when researching the topic.
     Now, you are chatting with an expert to get information. Ask good questions to get more useful information.
     When you have no more question to ask, say "Thank you so much for your help!" to end the conversation.
-    Please only ask a question at a time and don't ask what you have asked before. Your questions should be related to the topic you want to write."""
+    Please only ask a question at a time and don't ask what you have asked before. Your questions should be related to the topic you want to write.
+    """
 
-    topic = dspy.InputField(prefix='Topic you want to write: ', format=str)
-    persona = dspy.InputField(prefix='Your persona besides being a Wikipedia writer: ', format=str)
-    conv = dspy.InputField(prefix='Conversation history:\n', format=str)
+    topic = dspy.InputField(prefix="Topic you want to write: ", format=str)
+    persona = dspy.InputField(
+        prefix="Your persona besides being a Wikipedia writer: ", format=str
+    )
+    conv = dspy.InputField(prefix="Conversation history:\n", format=str)
     question = dspy.OutputField(format=str)
 
 
 class QuestionToQuery(dspy.Signature):
     """You want to answer the question using Google search. What do you type in the search box?
-    Write the queries you will use in the following format:
-    - query 1
-    - query 2
-    ...
-    - query n"""
-
-    topic = dspy.InputField(prefix='Topic you are discussing about: ', format=str)
-    question = dspy.InputField(prefix='Question you want to answer: ', format=str)
+    Write the queries you will use in the following format:
+    - query 1
+    - query 2
+    ...
+    - query n"""
+
+    topic = dspy.InputField(prefix="Topic you are discussing about: ", format=str)
+    question = dspy.InputField(prefix="Question you want to answer: ", format=str)
     queries = dspy.OutputField(format=str)
 
 
 class AnswerQuestion(dspy.Signature):
     """You are an expert who can use information effectively. You are chatting with a Wikipedia writer who wants to write a Wikipedia page on topic you know. You have gathered the related information and will now use the information to form a response.
-    Make your response as informative as possible and make sure every sentence is supported by the gathered information. If [Gathered information] is not related to he [Topic] and [Question], output "Sorry, I don't have enough information to answer the question.".""" 
+    Make your response as informative as possible and make sure every sentence is supported by the gathered information. If [Gathered information] is not related to the [Topic] and [Question], output "Sorry, I don't have enough information to answer the question.".
+ """ - topic = dspy.InputField(prefix='Topic you are discussing about:', format=str) - conv = dspy.InputField(prefix='Question:\n', format=str) - info = dspy.InputField( - prefix='Gathered information:\n', format=str) + topic = dspy.InputField(prefix="Topic you are discussing about:", format=str) + conv = dspy.InputField(prefix="Question:\n", format=str) + info = dspy.InputField(prefix="Gathered information:\n", format=str) answer = dspy.OutputField( - prefix='Now give your response. (Try to use as many different sources as possible and add do not hallucinate.)\n', - format=str + prefix="Now give your response. (Try to use as many different sources as possible and add do not hallucinate.)\n", + format=str, ) @@ -153,8 +186,13 @@ class TopicExpert(dspy.Module): 4. Generate an answer using the retrieved information. """ - def __init__(self, engine: Union[dspy.dsp.LM, dspy.dsp.HFModel], - max_search_queries: int, search_top_k: int, retriever: Retriever): + def __init__( + self, + engine: Union[dspy.dsp.LM, dspy.dsp.HFModel], + max_search_queries: int, + search_top_k: int, + retriever: Retriever, + ): super().__init__() self.generate_queries = dspy.Predict(QuestionToQuery) self.retriever = retriever @@ -168,31 +206,43 @@ def forward(self, topic: str, question: str, ground_truth_url: str): with dspy.settings.context(lm=self.engine): # Identify: Break down question into queries. queries = self.generate_queries(topic=topic, question=question).queries - queries = [q.replace('-', '').strip().strip('"').strip('"').strip() for q in queries.split('\n')] - queries = queries[:self.max_search_queries] + queries = [ + q.replace("-", "").strip().strip('"').strip('"').strip() + for q in queries.split("\n") + ] + queries = queries[: self.max_search_queries] # Search - searched_results: List[StormInformation] = self.retriever.retrieve(list(set(queries)), - exclude_urls=[ground_truth_url]) + searched_results: List[StormInformation] = self.retriever.retrieve( + list(set(queries)), exclude_urls=[ground_truth_url] + ) if len(searched_results) > 0: # Evaluate: Simplify this part by directly using the top 1 snippet. - info = '' + info = "" for n, r in enumerate(searched_results): - info += '\n'.join(f'[{n + 1}]: {s}' for s in r.snippets[:1]) - info += '\n\n' + info += "\n".join(f"[{n + 1}]: {s}" for s in r.snippets[:1]) + info += "\n\n" - info = ArticleTextProcessing.limit_word_count_preserve_newline(info, 1000) + info = ArticleTextProcessing.limit_word_count_preserve_newline( + info, 1000 + ) try: - answer = self.answer_question(topic=topic, conv=question, info=info).answer - answer = ArticleTextProcessing.remove_uncompleted_sentences_with_citations(answer) + answer = self.answer_question( + topic=topic, conv=question, info=info + ).answer + answer = ArticleTextProcessing.remove_uncompleted_sentences_with_citations( + answer + ) except Exception as e: - logging.error(f'Error occurs when generating answer: {e}') - answer = 'Sorry, I cannot answer this question. Please ask another question.' + logging.error(f"Error occurs when generating answer: {e}") + answer = "Sorry, I cannot answer this question. Please ask another question." else: # When no information is found, the expert shouldn't hallucinate. - answer = 'Sorry, I cannot find information for this question. Please ask another question.' + answer = "Sorry, I cannot find information for this question. Please ask another question." 
- return dspy.Prediction(queries=queries, searched_results=searched_results, answer=answer) + return dspy.Prediction( + queries=queries, searched_results=searched_results, answer=answer + ) class StormKnowledgeCurationModule(KnowledgeCurationModule): @@ -200,15 +250,17 @@ class StormKnowledgeCurationModule(KnowledgeCurationModule): The interface for knowledge curation stage. Given topic, return collected information. """ - def __init__(self, - retriever: Retriever, - persona_generator: Optional[StormPersonaGenerator], - conv_simulator_lm: Union[dspy.dsp.LM, dspy.dsp.HFModel], - question_asker_lm: Union[dspy.dsp.LM, dspy.dsp.HFModel], - max_search_queries_per_turn: int, - search_top_k: int, - max_conv_turn: int, - max_thread_num: int): + def __init__( + self, + retriever: Retriever, + persona_generator: Optional[StormPersonaGenerator], + conv_simulator_lm: Union[dspy.dsp.LM, dspy.dsp.HFModel], + question_asker_lm: Union[dspy.dsp.LM, dspy.dsp.HFModel], + max_search_queries_per_turn: int, + search_top_k: int, + max_conv_turn: int, + max_thread_num: int, + ): """ Store args and finish initialization. """ @@ -224,14 +276,22 @@ def __init__(self, retriever=retriever, max_search_queries_per_turn=max_search_queries_per_turn, search_top_k=search_top_k, - max_turn=max_conv_turn + max_turn=max_conv_turn, ) def _get_considered_personas(self, topic: str, max_num_persona) -> List[str]: - return self.persona_generator.generate_persona(topic=topic, max_num_persona=max_num_persona) + return self.persona_generator.generate_persona( + topic=topic, max_num_persona=max_num_persona + ) - def _run_conversation(self, conv_simulator, topic, ground_truth_url, considered_personas, - callback_handler: BaseCallbackHandler) -> List[Tuple[str, List[DialogueTurn]]]: + def _run_conversation( + self, + conv_simulator, + topic, + ground_truth_url, + considered_personas, + callback_handler: BaseCallbackHandler, + ) -> List[Tuple[str, List[DialogueTurn]]]: """ Executes multiple conversation simulations concurrently, each with a different persona, and collects their dialog histories. The dialog history of each conversation is cleaned @@ -260,13 +320,16 @@ def run_conv(persona): topic=topic, ground_truth_url=ground_truth_url, persona=persona, - callback_handler=callback_handler + callback_handler=callback_handler, ) max_workers = min(self.max_thread_num, len(considered_personas)) with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: - future_to_persona = {executor.submit(run_conv, persona): persona for persona in considered_personas} + future_to_persona = { + executor.submit(run_conv, persona): persona + for persona in considered_personas + } if streamlit_connection: # Ensure the logging context is correct when connecting with Streamlit frontend. 
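The per-persona fan-out above follows the standard executor pattern; a self-contained sketch of the same idea, with an illustrative stand-in for the real conversation simulator:

    from concurrent.futures import ThreadPoolExecutor, as_completed

    def run_conv(persona):
        return f"dialogue history for {persona}"  # stand-in for conv_simulator(...)

    personas = ["historian", "policy analyst"]
    conversations = []
    with ThreadPoolExecutor(max_workers=min(10, len(personas))) as executor:
        future_to_persona = {executor.submit(run_conv, p): p for p in personas}
        for future in as_completed(future_to_persona):
            conversations.append((future_to_persona[future], future.result()))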
@@ -276,23 +339,27 @@ def run_conv(persona): for future in as_completed(future_to_persona): persona = future_to_persona[future] conv = future.result() - conversations.append((persona, ArticleTextProcessing.clean_up_citation(conv).dlg_history)) + conversations.append( + (persona, ArticleTextProcessing.clean_up_citation(conv).dlg_history) + ) return conversations - def research(self, - topic: str, - ground_truth_url: str, - callback_handler: BaseCallbackHandler, - max_perspective: int = 0, - disable_perspective: bool = True, - return_conversation_log=False) -> Union[StormInformationTable, Tuple[StormInformationTable, Dict]]: + def research( + self, + topic: str, + ground_truth_url: str, + callback_handler: BaseCallbackHandler, + max_perspective: int = 0, + disable_perspective: bool = True, + return_conversation_log=False, + ) -> Union[StormInformationTable, Tuple[StormInformationTable, Dict]]: """ Curate information and knowledge for the given topic Args: topic: topic of interest in natural language. - + Returns: collected_information: collected information in InformationTable type. """ @@ -303,19 +370,25 @@ def research(self, if disable_perspective: considered_personas = [""] else: - considered_personas = self._get_considered_personas(topic=topic, max_num_persona=max_perspective) + considered_personas = self._get_considered_personas( + topic=topic, max_num_persona=max_perspective + ) callback_handler.on_identify_perspective_end(perspectives=considered_personas) - # run conversation + # run conversation callback_handler.on_information_gathering_start() - conversations = self._run_conversation(conv_simulator=self.conv_simulator, - topic=topic, - ground_truth_url=ground_truth_url, - considered_personas=considered_personas, - callback_handler=callback_handler) + conversations = self._run_conversation( + conv_simulator=self.conv_simulator, + topic=topic, + ground_truth_url=ground_truth_url, + considered_personas=considered_personas, + callback_handler=callback_handler, + ) information_table = StormInformationTable(conversations) callback_handler.on_information_gathering_end() if return_conversation_log: - return information_table, StormInformationTable.construct_log_dict(conversations) + return information_table, StormInformationTable.construct_log_dict( + conversations + ) return information_table diff --git a/knowledge_storm/storm_wiki/modules/outline_generation.py b/knowledge_storm/storm_wiki/modules/outline_generation.py index 1f45b1c2..a96c7978 100644 --- a/knowledge_storm/storm_wiki/modules/outline_generation.py +++ b/knowledge_storm/storm_wiki/modules/outline_generation.py @@ -14,18 +14,19 @@ class StormOutlineGenerationModule(OutlineGenerationModule): curation stage, generate outline for the article. 
""" - def __init__(self, - outline_gen_lm: Union[dspy.dsp.LM, dspy.dsp.HFModel]): + def __init__(self, outline_gen_lm: Union[dspy.dsp.LM, dspy.dsp.HFModel]): super().__init__() self.outline_gen_lm = outline_gen_lm self.write_outline = WriteOutline(engine=self.outline_gen_lm) - def generate_outline(self, - topic: str, - information_table: StormInformationTable, - old_outline: Optional[StormArticle] = None, - callback_handler: BaseCallbackHandler = None, - return_draft_outline=False) -> Union[StormArticle, Tuple[StormArticle, StormArticle]]: + def generate_outline( + self, + topic: str, + information_table: StormInformationTable, + old_outline: Optional[StormArticle] = None, + callback_handler: BaseCallbackHandler = None, + return_draft_outline=False, + ) -> Union[StormArticle, Tuple[StormArticle, StormArticle]]: """ Generates an outline for an article based on the specified topic and the information gathered during the knowledge curation stage. This method can optionally return both the @@ -34,30 +35,38 @@ def generate_outline(self, Args: topic (str): The topic of the article. information_table (StormInformationTable): The information table containing the collected information. - old_outline (Optional[StormArticle]): An optional previous version of the article outline that can + old_outline (Optional[StormArticle]): An optional previous version of the article outline that can be used for reference or comparison. Defaults to None. - callback_handler (BaseCallbackHandler): An optional callback handler that can be used to trigger - custom callbacks at various stages of the outline generation process, such as when the information + callback_handler (BaseCallbackHandler): An optional callback handler that can be used to trigger + custom callbacks at various stages of the outline generation process, such as when the information organization starts. Defaults to None. - return_draft_outline (bool): A flag indicating whether the method should return both the final article - outline and a draft version of the outline. If False, only the final article outline is returned. + return_draft_outline (bool): A flag indicating whether the method should return both the final article + outline and a draft version of the outline. If False, only the final article outline is returned. Defaults to False. Returns: - Union[StormArticle, Tuple[StormArticle, StormArticle]]: Depending on the value of `return_draft_outline`, - this method returns either a single `StormArticle` object containing the final outline or a tuple of - two `StormArticle` objects, the first containing the final outline and the second containing the + Union[StormArticle, Tuple[StormArticle, StormArticle]]: Depending on the value of `return_draft_outline`, + this method returns either a single `StormArticle` object containing the final outline or a tuple of + two `StormArticle` objects, the first containing the final outline and the second containing the draft outline. 
""" if callback_handler is not None: callback_handler.on_information_organization_start() - concatenated_dialogue_turns = sum([conv for (_, conv) in information_table.conversations], []) - result = self.write_outline(topic=topic, dlg_history=concatenated_dialogue_turns, - callback_handler=callback_handler) - article_with_outline_only = StormArticle.from_outline_str(topic=topic, outline_str=result.outline) - article_with_draft_outline_only = StormArticle.from_outline_str(topic=topic, - outline_str=result.old_outline) + concatenated_dialogue_turns = sum( + [conv for (_, conv) in information_table.conversations], [] + ) + result = self.write_outline( + topic=topic, + dlg_history=concatenated_dialogue_turns, + callback_handler=callback_handler, + ) + article_with_outline_only = StormArticle.from_outline_str( + topic=topic, outline_str=result.outline + ) + article_with_draft_outline_only = StormArticle.from_outline_str( + topic=topic, outline_str=result.old_outline + ) if not return_draft_outline: return article_with_outline_only return article_with_outline_only, article_with_draft_outline_only @@ -72,25 +81,44 @@ def __init__(self, engine: Union[dspy.dsp.LM, dspy.dsp.HFModel]): self.write_page_outline = dspy.Predict(WritePageOutlineFromConv) self.engine = engine - def forward(self, topic: str, dlg_history, old_outline: Optional[str] = None, - callback_handler: BaseCallbackHandler = None): + def forward( + self, + topic: str, + dlg_history, + old_outline: Optional[str] = None, + callback_handler: BaseCallbackHandler = None, + ): trimmed_dlg_history = [] for turn in dlg_history: - if 'topic you' in turn.agent_utterance.lower() or 'topic you' in turn.user_utterance.lower(): + if ( + "topic you" in turn.agent_utterance.lower() + or "topic you" in turn.user_utterance.lower() + ): continue trimmed_dlg_history.append(turn) - conv = '\n'.join([f'Wikipedia Writer: {turn.user_utterance}\nExpert: {turn.agent_utterance}' for turn in - trimmed_dlg_history]) + conv = "\n".join( + [ + f"Wikipedia Writer: {turn.user_utterance}\nExpert: {turn.agent_utterance}" + for turn in trimmed_dlg_history + ] + ) conv = ArticleTextProcessing.remove_citations(conv) conv = ArticleTextProcessing.limit_word_count_preserve_newline(conv, 5000) with dspy.settings.context(lm=self.engine): if old_outline is None: - old_outline = ArticleTextProcessing.clean_up_outline(self.draft_page_outline(topic=topic).outline) + old_outline = ArticleTextProcessing.clean_up_outline( + self.draft_page_outline(topic=topic).outline + ) if callback_handler: - callback_handler.on_direct_outline_generation_end(outline=old_outline) + callback_handler.on_direct_outline_generation_end( + outline=old_outline + ) outline = ArticleTextProcessing.clean_up_outline( - self.write_page_outline(topic=topic, old_outline=old_outline, conv=conv).outline) + self.write_page_outline( + topic=topic, old_outline=old_outline, conv=conv + ).outline + ) if callback_handler: callback_handler.on_outline_refinement_end(outline=outline) @@ -99,10 +127,10 @@ def forward(self, topic: str, dlg_history, old_outline: Optional[str] = None, class WritePageOutline(dspy.Signature): """Write an outline for a Wikipedia page. - Here is the format of your writing: - 1. Use "#" Title" to indicate section title, "##" Title" to indicate subsection title, "###" Title" to indicate subsubsection title, and so on. - 2. Do not include other information. - 3. Do not include topic name itself in the outline. + Here is the format of your writing: + 1. 
Use "#" Title" to indicate section title, "##" Title" to indicate subsection title, "###" Title" to indicate subsubsection title, and so on. + 2. Do not include other information. + 3. Do not include topic name itself in the outline. """ topic = dspy.InputField(prefix="The topic you want to write: ", format=str) @@ -124,10 +152,10 @@ def forward(self, topic: str): class WritePageOutlineFromConv(dspy.Signature): """Improve an outline for a Wikipedia page. You already have a draft outline that covers the general information. Now you want to improve it based on the information learned from an information-seeking conversation to make it more informative. - Here is the format of your writing: - 1. Use "#" Title" to indicate section title, "##" Title" to indicate subsection title, "###" Title" to indicate subsubsection title, and so on. - 2. Do not include other information. - 3. Do not include topic name itself in the outline. + Here is the format of your writing: + 1. Use "#" Title" to indicate section title, "##" Title" to indicate subsection title, "###" Title" to indicate subsubsection title, and so on. + 2. Do not include other information. + 3. Do not include topic name itself in the outline. """ topic = dspy.InputField(prefix="The topic you want to write: ", format=str) @@ -135,5 +163,5 @@ class WritePageOutlineFromConv(dspy.Signature): old_outline = dspy.OutputField(prefix="Current outline:\n", format=str) outline = dspy.OutputField( prefix='Write the Wikipedia page outline (Use "#" Title" to indicate section title, "##" Title" to indicate subsection title, ...):\n', - format=str + format=str, ) diff --git a/knowledge_storm/storm_wiki/modules/persona_generator.py b/knowledge_storm/storm_wiki/modules/persona_generator.py index 5150e31b..c51dc0cc 100644 --- a/knowledge_storm/storm_wiki/modules/persona_generator.py +++ b/knowledge_storm/storm_wiki/modules/persona_generator.py @@ -11,19 +11,27 @@ def get_wiki_page_title_and_toc(url): """Get the main title and table of contents from an url of a Wikipedia page.""" response = requests.get(url) - soup = BeautifulSoup(response.content, 'html.parser') + soup = BeautifulSoup(response.content, "html.parser") # Get the main title from the first h1 tag - main_title = soup.find('h1').text.replace('[edit]', '').strip().replace('\xa0', ' ') + main_title = soup.find("h1").text.replace("[edit]", "").strip().replace("\xa0", " ") toc = "" levels = [] - excluded_sections = {'Contents', 'See also', 'Notes', 'References', 'External links'} + excluded_sections = { + "Contents", + "See also", + "Notes", + "References", + "External links", + } # Start processing from h2 to exclude the main title from TOC - for header in soup.find_all(['h2', 'h3', "h4", "h5", "h6"]): - level = int(header.name[1]) # Extract the numeric part of the header tag (e.g., '2' from 'h2') - section_title = header.text.replace('[edit]', '').strip().replace('\xa0', ' ') + for header in soup.find_all(["h2", "h3", "h4", "h5", "h6"]): + level = int( + header.name[1] + ) # Extract the numeric part of the header tag (e.g., '2' from 'h2') + section_title = header.text.replace("[edit]", "").strip().replace("\xa0", " ") if section_title in excluded_sections: continue @@ -39,9 +47,9 @@ def get_wiki_page_title_and_toc(url): class FindRelatedTopic(dspy.Signature): """I'm writing a Wikipedia page for a topic mentioned below. Please identify and recommend some Wikipedia pages on closely related subjects. 
I'm looking for examples that provide insights into interesting aspects commonly associated with this topic, or examples that help me understand the typical content and structure included in Wikipedia pages for similar topics. - Please list the urls in separate lines.""" + Please list the urls in separate lines.""" - topic = dspy.InputField(prefix='Topic of interest:', format=str) + topic = dspy.InputField(prefix="Topic of interest:", format=str) related_topics = dspy.OutputField(format=str) @@ -50,8 +58,10 @@ class GenPersona(dspy.Signature): Give your answer in the following format: 1. short summary of editor 1: description\n2. short summary of editor 2: description\n... """ - topic = dspy.InputField(prefix='Topic of interest:', format=str) - examples = dspy.InputField(prefix='Wiki page outlines of related topics for inspiration:\n', format=str) + topic = dspy.InputField(prefix="Topic of interest:", format=str) + examples = dspy.InputField( + prefix="Wiki page outlines of related topics for inspiration:\n", format=str + ) personas = dspy.OutputField(format=str) @@ -69,38 +79,44 @@ def forward(self, topic: str, draft=None): # Get section names from wiki pages of relevant topics for inspiration. related_topics = self.find_related_topic(topic=topic).related_topics urls = [] - for s in related_topics.split('\n'): - if 'http' in s: - urls.append(s[s.find('http'):]) + for s in related_topics.split("\n"): + if "http" in s: + urls.append(s[s.find("http") :]) examples = [] for url in urls: try: title, toc = get_wiki_page_title_and_toc(url) - examples.append(f'Title: {title}\nTable of Contents: {toc}') + examples.append(f"Title: {title}\nTable of Contents: {toc}") except Exception as e: - logging.error(f'Error occurs when processing {url}: {e}') + logging.error(f"Error occurs when processing {url}: {e}") continue if len(examples) == 0: - examples.append('N/A') - gen_persona_output = self.gen_persona(topic=topic, examples='\n----------\n'.join(examples)).personas + examples.append("N/A") + gen_persona_output = self.gen_persona( + topic=topic, examples="\n----------\n".join(examples) + ).personas personas = [] - for s in gen_persona_output.split('\n'): - match = re.search(r'\d+\.\s*(.*)', s) + for s in gen_persona_output.split("\n"): + match = re.search(r"\d+\.\s*(.*)", s) if match: personas.append(match.group(1)) sorted_personas = personas - return dspy.Prediction(personas=personas, raw_personas_output=sorted_personas, related_topics=related_topics) + return dspy.Prediction( + personas=personas, + raw_personas_output=sorted_personas, + related_topics=related_topics, + ) -class StormPersonaGenerator(): +class StormPersonaGenerator: """ A generator class for creating personas based on a given topic. - This class uses an underlying engine to generate personas tailored to the specified topic. - The generator integrates with a `CreateWriterWithPersona` instance to create diverse personas, + This class uses an underlying engine to generate personas tailored to the specified topic. + The generator integrates with a `CreateWriterWithPersona` instance to create diverse personas, including a default 'Basic fact writer' persona. Attributes: @@ -133,6 +149,6 @@ def generate_persona(self, topic: str, max_num_persona: int = 3) -> List[str]: and up to `max_num_persona` additional personas generated based on the topic. """ personas = self.create_writer_with_persona(topic=topic) - default_persona = 'Basic fact writer: Basic fact writer focusing on broadly covering the basic facts about the topic.' 
+ default_persona = "Basic fact writer: Basic fact writer focusing on broadly covering the basic facts about the topic." considered_personas = [default_persona] + personas.personas[:max_num_persona] return considered_personas diff --git a/knowledge_storm/storm_wiki/modules/retriever.py b/knowledge_storm/storm_wiki/modules/retriever.py index 179ae99b..85df63ec 100644 --- a/knowledge_storm/storm_wiki/modules/retriever.py +++ b/knowledge_storm/storm_wiki/modules/retriever.py @@ -149,7 +149,8 @@ "WordPress.com", "Worldometer", "YouTube", - "ZDNet"} + "ZDNet", +} DEPRECATED = { "Al_Mayadeen", "ANNA_News", @@ -197,7 +198,7 @@ "VDARE", "Voltaire_Network", "WorldNetDaily", - "Zero_Hedge" + "Zero_Hedge", } BLACKLISTED = { "Advameg", @@ -218,7 +219,7 @@ "The_Points_Guy_(sponsored_content)", "Swarajya", "Veterans_Today", - "ZoomInfo" + "ZoomInfo", } @@ -237,14 +238,20 @@ class StormRetriever(Retriever): def __init__(self, rm: dspy.Retrieve, k=3): super().__init__(search_top_k=k) self._rm = rm - if hasattr(rm, 'is_valid_source'): + if hasattr(rm, "is_valid_source"): rm.is_valid_source = is_valid_wikipedia_source - def retrieve(self, query: Union[str, List[str]], exclude_urls: List[str] = []) -> List[Information]: - retrieved_data_list = self._rm(query_or_queries=query, exclude_urls=exclude_urls) + def retrieve( + self, query: Union[str, List[str]], exclude_urls: List[str] = [] + ) -> List[Information]: + retrieved_data_list = self._rm( + query_or_queries=query, exclude_urls=exclude_urls + ) for data in retrieved_data_list: - for i in range(len(data['snippets'])): + for i in range(len(data["snippets"])): # STORM generate the article with citations. We do not consider multi-hop citations. # Remove citations in the source to avoid confusion. - data['snippets'][i] = ArticleTextProcessing.remove_citations(data['snippets'][i]) + data["snippets"][i] = ArticleTextProcessing.remove_citations( + data["snippets"][i] + ) return [StormInformation.from_dict(data) for data in retrieved_data_list] diff --git a/knowledge_storm/storm_wiki/modules/storm_dataclass.py b/knowledge_storm/storm_wiki/modules/storm_dataclass.py index 4f54ec46..43826ecc 100644 --- a/knowledge_storm/storm_wiki/modules/storm_dataclass.py +++ b/knowledge_storm/storm_wiki/modules/storm_dataclass.py @@ -51,22 +51,29 @@ def from_dict(cls, info_dict): Returns: StormInformation: An instance of StormInformation. 
""" - return cls(info_dict['url'], info_dict['description'], info_dict['snippets'], info_dict['title']) + return cls( + info_dict["url"], + info_dict["description"], + info_dict["snippets"], + info_dict["title"], + ) def to_dict(self): - return {"url": self.uuid, - "description": self.description, - "snippets": self.snippets, - "title": self.title} + return { + "url": self.uuid, + "description": self.description, + "snippets": self.snippets, + "title": self.title, + } class DialogueTurn: def __init__( - self, - agent_utterance: str = None, - user_utterance: str = None, - search_queries: Optional[List[str]] = None, - search_results: Optional[List[Union[StormInformation, Dict]]] = None + self, + agent_utterance: str = None, + user_utterance: str = None, + search_queries: Optional[List[str]] = None, + search_results: Optional[List[Union[StormInformation, Dict]]] = None, ): self.agent_utterance = agent_utterance self.user_utterance = user_utterance @@ -76,7 +83,9 @@ def __init__( if self.search_results: for idx in range(len(self.search_results)): if type(self.search_results[idx]) == dict: - self.search_results[idx] = StormInformation.from_dict(self.search_results[idx]) + self.search_results[idx] = StormInformation.from_dict( + self.search_results[idx] + ) def log(self): """ @@ -85,10 +94,10 @@ def log(self): return OrderedDict( { - 'agent_utterance': self.agent_utterance, - 'user_utterance': self.user_utterance, - 'search_queries': self.search_queries, - 'search_results': [data.to_dict() for data in self.search_results], + "agent_utterance": self.agent_utterance, + "user_utterance": self.user_utterance, + "search_queries": self.search_queries, + "search_results": [data.to_dict() for data in self.search_results], } ) @@ -98,7 +107,7 @@ class StormInformationTable(InformationTable): The InformationTable class serves as data class to store the information collected during KnowledgeCuration stage. - Create subclass to incorporate more information as needed. For example, + Create subclass to incorporate more information as needed. For example, in STORM paper https://arxiv.org/pdf/2402.14207.pdf, additional information would be perspective guided dialogue history. 
""" @@ -106,13 +115,17 @@ class StormInformationTable(InformationTable): def __init__(self, conversations=List[Tuple[str, List[DialogueTurn]]]): super().__init__() self.conversations = conversations - self.url_to_info: Dict[str, StormInformation] = StormInformationTable.construct_url_to_info(self.conversations) + self.url_to_info: Dict[str, StormInformation] = ( + StormInformationTable.construct_url_to_info(self.conversations) + ) @staticmethod - def construct_url_to_info(conversations: List[Tuple[str, List[DialogueTurn]]]) -> Dict[str, StormInformation]: + def construct_url_to_info( + conversations: List[Tuple[str, List[DialogueTurn]]] + ) -> Dict[str, StormInformation]: url_to_info = {} - for (persona, conv) in conversations: + for persona, conv in conversations: for turn in conv: for storm_info in turn.search_results: if storm_info.url in url_to_info: @@ -124,14 +137,13 @@ def construct_url_to_info(conversations: List[Tuple[str, List[DialogueTurn]]]) - return url_to_info @staticmethod - def construct_log_dict(conversations: List[Tuple[str, List[DialogueTurn]]]) -> List[Dict[str, Union[str, Any]]]: + def construct_log_dict( + conversations: List[Tuple[str, List[DialogueTurn]]] + ) -> List[Dict[str, Union[str, Any]]]: conversation_log = [] - for (persona, conv) in conversations: + for persona, conv in conversations: conversation_log.append( - { - 'perspective': persona, - 'dlg_turns': [turn.log() for turn in conv] - } + {"perspective": persona, "dlg_turns": [turn.log() for turn in conv]} ) return conversation_log @@ -146,22 +158,26 @@ def from_conversation_log_file(cls, path): conversation_log_data = FileIOHelper.load_json(path) conversations = [] for item in conversation_log_data: - dialogue_turns = [DialogueTurn(**turn) for turn in item['dlg_turns']] - persona = item['perspective'] + dialogue_turns = [DialogueTurn(**turn) for turn in item["dlg_turns"]] + persona = item["perspective"] conversations.append((persona, dialogue_turns)) return cls(conversations) def prepare_table_for_retrieval(self): - self.encoder = SentenceTransformer('paraphrase-MiniLM-L6-v2') + self.encoder = SentenceTransformer("paraphrase-MiniLM-L6-v2") self.collected_urls = [] self.collected_snippets = [] for url, information in self.url_to_info.items(): for snippet in information.snippets: self.collected_urls.append(url) self.collected_snippets.append(snippet) - self.encoded_snippets = self.encoder.encode(self.collected_snippets, show_progress_bar=False) + self.encoded_snippets = self.encoder.encode( + self.collected_snippets, show_progress_bar=False + ) - def retrieve_information(self, queries: Union[List[str], str], search_top_k) -> List[StormInformation]: + def retrieve_information( + self, queries: Union[List[str], str], search_top_k + ) -> List[StormInformation]: selected_urls = [] selected_snippets = [] if type(queries) is str: @@ -191,14 +207,13 @@ def retrieve_information(self, queries: Union[List[str], str], search_top_k) -> class StormArticle(Article): def __init__(self, topic_name): super().__init__(topic_name=topic_name) - self.reference = { - "url_to_unified_index": {}, - "url_to_info": {} - } + self.reference = {"url_to_unified_index": {}, "url_to_info": {}} - def find_section(self, node: ArticleSectionNode, name: str) -> Optional[ArticleSectionNode]: + def find_section( + self, node: ArticleSectionNode, name: str + ) -> Optional[ArticleSectionNode]: """ - Return the node of the section given the section name. + Return the node of the section given the section name. 
         Args:
             node: the node as the root to find.
@@ -215,17 +230,18 @@ def find_section(self, node: ArticleSectionNode, name: str) -> Optional[ArticleS
                 return result
         return None

-    def _merge_new_info_to_references(self, new_info_list: List[StormInformation], index_to_keep=None) -> Dict[
-        int, int]:
+    def _merge_new_info_to_references(
+        self, new_info_list: List[StormInformation], index_to_keep=None
+    ) -> Dict[int, int]:
         """
         Merges new storm information into existing references and updates the citation index mapping.

         Args:
-            new_info_list (List[StormInformation]): A list of dictionaries representing new storm information. 
+            new_info_list (List[StormInformation]): A list of StormInformation objects with the newly collected information.
             index_to_keep (List[int]): A list of indices of the new_info_list to keep. If None, keep all.

         Returns:
-            Dict[int, int]: A dictionary mapping the index of each storm information piece in the input list 
+            Dict[int, int]: A dictionary mapping the index of each storm information piece in the input list
                 to its unified citation index in the references.
         """
         citation_idx_mapping = {}
         for idx, storm_info in enumerate(new_info_list):
             if index_to_keep is not None and idx not in index_to_keep:
                 continue
             url = storm_info.url
             if url not in self.reference["url_to_unified_index"]:
-                self.reference["url_to_unified_index"][url] = len(
-                    self.reference["url_to_unified_index"]) + 1  # The citation index starts from 1.
+                self.reference["url_to_unified_index"][url] = (
+                    len(self.reference["url_to_unified_index"]) + 1
+                )  # The citation index starts from 1.
                 self.reference["url_to_info"][url] = storm_info
             else:
                 existing_snippets = self.reference["url_to_info"][url].snippets
                 existing_snippets.extend(storm_info.snippets)
-                self.reference["url_to_info"][url].snippets = list(set(existing_snippets))
+                self.reference["url_to_info"][url].snippets = list(
+                    set(existing_snippets)
+                )
             citation_idx_mapping[idx + 1] = self.reference["url_to_unified_index"][
-                url]  # The citation index starts from 1.
+                url
+            ]  # The citation index starts from 1.
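+
+        # Illustrative example (indices are placeholders): if new_info_list[0]
+        # and new_info_list[1] share a URL whose unified index is 3, the loop
+        # above yields citation_idx_mapping == {1: 3, 2: 3}.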
         return citation_idx_mapping

-    def insert_or_create_section(self, article_dict: Dict[str, Dict], parent_section_name: str = None,
-                                 trim_children=False):
-        parent_node = self.root if parent_section_name is None else self.find_section(self.root, parent_section_name)
+    def insert_or_create_section(
+        self,
+        article_dict: Dict[str, Dict],
+        parent_section_name: str = None,
+        trim_children=False,
+    ):
+        parent_node = (
+            self.root
+            if parent_section_name is None
+            else self.find_section(self.root, parent_section_name)
+        )

         if trim_children:
             section_names = set(article_dict.keys())
@@ -258,56 +286,83 @@
         for section_name, content_dict in article_dict.items():
             current_section_node = self.find_section(parent_node, section_name)
             if current_section_node is None:
-                current_section_node = ArticleSectionNode(section_name=section_name,
-                                                          content=content_dict["content"].strip())
-                insert_to_front = parent_node.section_name == self.root.section_name and current_section_node.section_name == "summary"
-                parent_node.add_child(current_section_node, insert_to_front=insert_to_front)
+                current_section_node = ArticleSectionNode(
+                    section_name=section_name, content=content_dict["content"].strip()
+                )
+                insert_to_front = (
+                    parent_node.section_name == self.root.section_name
+                    and current_section_node.section_name == "summary"
+                )
+                parent_node.add_child(
+                    current_section_node, insert_to_front=insert_to_front
+                )
             else:
                 current_section_node.content = content_dict["content"].strip()

-            self.insert_or_create_section(article_dict=content_dict["subsections"], parent_section_name=section_name,
-                                          trim_children=True)
+            self.insert_or_create_section(
+                article_dict=content_dict["subsections"],
+                parent_section_name=section_name,
+                trim_children=True,
+            )

-    def update_section(self,
-                       current_section_content: str,
-                       current_section_info_list: List[StormInformation],
-                       parent_section_name: Optional[str] = None) -> Optional[ArticleSectionNode]:
+    def update_section(
+        self,
+        current_section_content: str,
+        current_section_info_list: List[StormInformation],
+        parent_section_name: Optional[str] = None,
+    ) -> Optional[ArticleSectionNode]:
         """
-        Add new section to the article. 
+        Add a new section to the article.

         Args:
-            current_section_name: new section heading name in string format.
+            current_section_content: new section content in string format, starting
+                with the section heading (e.g. "# Section name").
+            current_section_info_list: a list of StormInformation objects cited by
+                the new section content.
             parent_section_name: under which parent section to add the new one. Default to root.
-            current_section_content: optional section content. 
-
+
         Returns:
             the ArticleSectionNode for current section if successfully created / updated. Otherwise None.
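
         Example (illustrative; the heading, citation index, and info object are
         placeholders):

             article.update_section(
                 current_section_content="# History\nEarly work began in 1990 [1].",
                 current_section_info_list=[storm_info],
                 parent_section_name=None,  # None attaches the section under the root
             )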
""" if current_section_info_list is not None: - references = set([int(x) for x in re.findall(r'\[(\d+)\]', current_section_content)]) + references = set( + [int(x) for x in re.findall(r"\[(\d+)\]", current_section_content)] + ) # for any reference number greater than max number of references, delete the reference if len(references) > 0: max_ref_num = max(references) if max_ref_num > len(current_section_info_list): for i in range(len(current_section_info_list), max_ref_num + 1): - current_section_content = current_section_content.replace(f'[{i}]', '') + current_section_content = current_section_content.replace( + f"[{i}]", "" + ) if i in references: references.remove(i) # for any reference that is not used, trim it from current_section_info_list index_to_keep = [i - 1 for i in references] - citation_mapping = self._merge_new_info_to_references(current_section_info_list, index_to_keep) - current_section_content = ArticleTextProcessing.update_citation_index(current_section_content, - citation_mapping) + citation_mapping = self._merge_new_info_to_references( + current_section_info_list, index_to_keep + ) + current_section_content = ArticleTextProcessing.update_citation_index( + current_section_content, citation_mapping + ) if parent_section_name is None: parent_section_name = self.root.section_name - article_dict = ArticleTextProcessing.parse_article_into_dict(current_section_content) - self.insert_or_create_section(article_dict=article_dict, parent_section_name=parent_section_name, - trim_children=False) + article_dict = ArticleTextProcessing.parse_article_into_dict( + current_section_content + ) + self.insert_or_create_section( + article_dict=article_dict, + parent_section_name=parent_section_name, + trim_children=False, + ) - def get_outline_as_list(self, root_section_name: Optional[str] = None, add_hashtags: bool = False, - include_root: bool = True) -> List[str]: + def get_outline_as_list( + self, + root_section_name: Optional[str] = None, + add_hashtags: bool = False, + include_root: bool = True, + ) -> List[str]: """ Get outline of the article as a list. @@ -320,7 +375,7 @@ def get_outline_as_list(self, root_section_name: Optional[str] = None, add_hasht ###section1.2 ##section2 article.get_outline_as_list("section1") returns [section1, section1.1, section1.2, section2] - + Returns: list of section and subsection names. """ @@ -334,8 +389,14 @@ def get_outline_as_list(self, root_section_name: Optional[str] = None, add_hasht result = [] def preorder_traverse(node, level): - prefix = "#" * level if add_hashtags else "" # Adjust level if excluding root - result.append(f"{prefix} {node.section_name}".strip() if add_hashtags else node.section_name) + prefix = ( + "#" * level if add_hashtags else "" + ) # Adjust level if excluding root + result.append( + f"{prefix} {node.section_name}".strip() + if add_hashtags + else node.section_name + ) for child in node.children: preorder_traverse(child, level + 1) @@ -350,7 +411,7 @@ def preorder_traverse(node, level): def to_string(self) -> str: """ Get outline of the article as a list. - + Returns: list of section and subsection names. 
""" @@ -376,7 +437,9 @@ def reorder_reference_index(self): def pre_order_find_index(node): if node is not None: if node.content is not None and node.content: - ref_indices.extend(ArticleTextProcessing.parse_citation_indices(node.content)) + ref_indices.extend( + ArticleTextProcessing.parse_citation_indices(node.content) + ) for child in node.children: pre_order_find_index(child) @@ -391,7 +454,9 @@ def pre_order_find_index(node): def pre_order_update_index(node): if node is not None: if node.content is not None and node.content: - node.content = ArticleTextProcessing.update_citation_index(node.content, ref_index_mapping) + node.content = ArticleTextProcessing.update_citation_index( + node.content, ref_index_mapping + ) for child in node.children: pre_order_update_index(child) @@ -442,18 +507,18 @@ def from_outline_str(cls, topic: str, outline_str: str): instance = cls(topic) if lines: - a = lines[0].startswith('#') and lines[0].replace('#', '').strip().lower() + a = lines[0].startswith("#") and lines[0].replace("#", "").strip().lower() b = topic.lower().replace("_", " ") - adjust_level = lines[0].startswith('#') and lines[0].replace('#', - '').strip().lower() == topic.lower().replace( - "_", " ") + adjust_level = lines[0].startswith("#") and lines[0].replace( + "#", "" + ).strip().lower() == topic.lower().replace("_", " ") if adjust_level: lines = lines[1:] node_stack = [(0, instance.root)] # Stack to keep track of (level, node) for line in lines: - level = line.count('#') - adjust_level - section_name = line.replace('#', '').strip() + level = line.count("#") - adjust_level + section_name = line.replace("#", "").strip() if section_name == topic: continue @@ -487,7 +552,9 @@ def from_string(cls, topic_name: str, article_text: str, references: dict): article = cls(topic_name=topic_name) article.insert_or_create_section(article_dict=article_dict) for url in list(references["url_to_info"]): - references["url_to_info"][url] = StormInformation.from_dict(references["url_to_info"][url]) + references["url_to_info"][url] = StormInformation.from_dict( + references["url_to_info"][url] + ) article.reference = references return article diff --git a/knowledge_storm/utils.py b/knowledge_storm/utils.py index 5cf6f457..d07d067c 100644 --- a/knowledge_storm/utils.py +++ b/knowledge_storm/utils.py @@ -17,7 +17,7 @@ def load_api_key(toml_file_path): try: - with open(toml_file_path, 'r') as file: + with open(toml_file_path, "r") as file: data = toml.load(file) except FileNotFoundError: print(f"File not found: {toml_file_path}", file=sys.stderr) @@ -53,19 +53,19 @@ def limit_word_count_preserve_newline(input_string, max_word_count): """ word_count = 0 - limited_string = '' + limited_string = "" - for word in input_string.split('\n'): + for word in input_string.split("\n"): line_words = word.split() for lw in line_words: if word_count < max_word_count: - limited_string += lw + ' ' + limited_string += lw + " " word_count += 1 else: break if word_count >= max_word_count: break - limited_string = limited_string.strip() + '\n' + limited_string = limited_string.strip() + "\n" return limited_string.strip() @@ -83,7 +83,7 @@ def remove_citations(s): str: The string with all citation patterns removed. """ - return re.sub(r'\[\d+(?:,\s*\d+)*\]', '', s) + return re.sub(r"\[\d+(?:,\s*\d+)*\]", "", s) @staticmethod def parse_citation_indices(s): @@ -96,7 +96,7 @@ def parse_citation_indices(s): Returns: List[int]: A list of unique citation indexes extracted from the content, in the order they appear. 
""" - matches = re.findall(r'\[\d+\]', s) + matches = re.findall(r"\[\d+\]", s) return [int(index[1:-1]) for index in matches] @staticmethod @@ -117,19 +117,21 @@ def remove_uncompleted_sentences_with_citations(text): # Convert citations like [1, 2, 3] to [1][2][3]. def replace_with_individual_brackets(match): - numbers = match.group(1).split(', ') - return ' '.join(f'[{n}]' for n in numbers) + numbers = match.group(1).split(", ") + return " ".join(f"[{n}]" for n in numbers) # Deduplicate and sort individual groups of citations. def deduplicate_group(match): citations = match.group(0) - unique_citations = list(set(re.findall(r'\[\d+\]', citations))) - sorted_citations = sorted(unique_citations, key=lambda x: int(x.strip('[]'))) + unique_citations = list(set(re.findall(r"\[\d+\]", citations))) + sorted_citations = sorted( + unique_citations, key=lambda x: int(x.strip("[]")) + ) # Return the sorted unique citations as a string - return ''.join(sorted_citations) + return "".join(sorted_citations) - text = re.sub(r'\[([0-9, ]+)\]', replace_with_individual_brackets, text) - text = re.sub(r'(\[\d+\])+', deduplicate_group, text) + text = re.sub(r"\[([0-9, ]+)\]", replace_with_individual_brackets, text) + text = re.sub(r"(\[\d+\])+", deduplicate_group, text) # Deprecated: Remove sentence without proper ending punctuation and citations. # Split the text into sentences (including citations). @@ -150,29 +152,38 @@ def deduplicate_group(match): # combined_sentences += ' '.join(trailing_citations) # Regex pattern to match sentence endings, including optional citation markers. - eos_pattern = r'([.!?])\s*(\[\d+\])?\s*' + eos_pattern = r"([.!?])\s*(\[\d+\])?\s*" matches = list(re.finditer(eos_pattern, text)) if matches: last_match = matches[-1] - text = text[:last_match.end()].strip() + text = text[: last_match.end()].strip() return text @staticmethod def clean_up_citation(conv): for turn in conv.dlg_history: - turn.agent_utterance = turn.agent_utterance[:turn.agent_utterance.find('References:')] - turn.agent_utterance = turn.agent_utterance[:turn.agent_utterance.find('Sources:')] - turn.agent_utterance = turn.agent_utterance.replace('Answer:', '').strip() + turn.agent_utterance = turn.agent_utterance[ + : turn.agent_utterance.find("References:") + ] + turn.agent_utterance = turn.agent_utterance[ + : turn.agent_utterance.find("Sources:") + ] + turn.agent_utterance = turn.agent_utterance.replace("Answer:", "").strip() try: - max_ref_num = max([int(x) for x in re.findall(r'\[(\d+)\]', turn.agent_utterance)]) + max_ref_num = max( + [int(x) for x in re.findall(r"\[(\d+)\]", turn.agent_utterance)] + ) except Exception as e: max_ref_num = 0 if max_ref_num > len(turn.search_results): for i in range(len(turn.search_results), max_ref_num + 1): - turn.agent_utterance = turn.agent_utterance.replace(f'[{i}]', '') - turn.agent_utterance = ArticleTextProcessing.remove_uncompleted_sentences_with_citations( - turn.agent_utterance) + turn.agent_utterance = turn.agent_utterance.replace(f"[{i}]", "") + turn.agent_utterance = ( + ArticleTextProcessing.remove_uncompleted_sentences_with_citations( + turn.agent_utterance + ) + ) return conv @@ -181,36 +192,46 @@ def clean_up_outline(outline, topic=""): output_lines = [] current_level = 0 # To track the current section level - for line in outline.split('\n'): + for line in outline.split("\n"): stripped_line = line.strip() if topic != "" and f"# {topic.lower()}" in stripped_line.lower(): output_lines = [] # Check if the line is a section header - if 
stripped_line.startswith('#'): - current_level = stripped_line.count('#') + if stripped_line.startswith("#"): + current_level = stripped_line.count("#") output_lines.append(stripped_line) # Check if the line is a bullet point - elif stripped_line.startswith('-'): - subsection_header = '#' * (current_level + 1) + ' ' + stripped_line[1:].strip() + elif stripped_line.startswith("-"): + subsection_header = ( + "#" * (current_level + 1) + " " + stripped_line[1:].strip() + ) output_lines.append(subsection_header) - outline = '\n'.join(output_lines) + outline = "\n".join(output_lines) # Remove references. - outline = re.sub(r"#[#]? See also.*?(?=##|$)", '', outline, flags=re.DOTALL) - outline = re.sub(r"#[#]? See Also.*?(?=##|$)", '', outline, flags=re.DOTALL) - outline = re.sub(r"#[#]? Notes.*?(?=##|$)", '', outline, flags=re.DOTALL) - outline = re.sub(r"#[#]? References.*?(?=##|$)", '', outline, flags=re.DOTALL) - outline = re.sub(r"#[#]? External links.*?(?=##|$)", '', outline, flags=re.DOTALL) - outline = re.sub(r"#[#]? External Links.*?(?=##|$)", '', outline, flags=re.DOTALL) - outline = re.sub(r"#[#]? Bibliography.*?(?=##|$)", '', outline, flags=re.DOTALL) - outline = re.sub(r"#[#]? Further reading*?(?=##|$)", '', outline, flags=re.DOTALL) - outline = re.sub(r"#[#]? Further Reading*?(?=##|$)", '', outline, flags=re.DOTALL) - outline = re.sub(r"#[#]? Summary.*?(?=##|$)", '', outline, flags=re.DOTALL) - outline = re.sub(r"#[#]? Appendices.*?(?=##|$)", '', outline, flags=re.DOTALL) - outline = re.sub(r"#[#]? Appendix.*?(?=##|$)", '', outline, flags=re.DOTALL) + outline = re.sub(r"#[#]? See also.*?(?=##|$)", "", outline, flags=re.DOTALL) + outline = re.sub(r"#[#]? See Also.*?(?=##|$)", "", outline, flags=re.DOTALL) + outline = re.sub(r"#[#]? Notes.*?(?=##|$)", "", outline, flags=re.DOTALL) + outline = re.sub(r"#[#]? References.*?(?=##|$)", "", outline, flags=re.DOTALL) + outline = re.sub( + r"#[#]? External links.*?(?=##|$)", "", outline, flags=re.DOTALL + ) + outline = re.sub( + r"#[#]? External Links.*?(?=##|$)", "", outline, flags=re.DOTALL + ) + outline = re.sub(r"#[#]? Bibliography.*?(?=##|$)", "", outline, flags=re.DOTALL) + outline = re.sub( + r"#[#]? Further reading*?(?=##|$)", "", outline, flags=re.DOTALL + ) + outline = re.sub( + r"#[#]? Further Reading*?(?=##|$)", "", outline, flags=re.DOTALL + ) + outline = re.sub(r"#[#]? Summary.*?(?=##|$)", "", outline, flags=re.DOTALL) + outline = re.sub(r"#[#]? Appendices.*?(?=##|$)", "", outline, flags=re.DOTALL) + outline = re.sub(r"#[#]? Appendix.*?(?=##|$)", "", outline, flags=re.DOTALL) return outline @@ -221,34 +242,40 @@ def clean_up_section(text): 2. Deduplicate individual groups of citations. 3. 
           Remove unnecessary summary."""
-        paragraphs = text.split('\n')
+        paragraphs = text.split("\n")
         output_paragraphs = []
         summary_sec_flag = False
         for p in paragraphs:
             p = p.strip()
             if len(p) == 0:
                 continue
-            if not p.startswith('#'):
+            if not p.startswith("#"):
                 p = ArticleTextProcessing.remove_uncompleted_sentences_with_citations(p)
             if summary_sec_flag:
-                if p.startswith('#'):
+                if p.startswith("#"):
                     summary_sec_flag = False
                 else:
                     continue
-            if p.startswith('Overall') or p.startswith('In summary') or p.startswith('In conclusion'):
+            if (
+                p.startswith("Overall")
+                or p.startswith("In summary")
+                or p.startswith("In conclusion")
+            ):
                 continue
-            if "# Summary" in p or '# Conclusion' in p:
+            if "# Summary" in p or "# Conclusion" in p:
                 summary_sec_flag = True
                 continue
             output_paragraphs.append(p)

-        return '\n\n'.join(output_paragraphs)  # Join with '\n\n' for markdown format.
+        return "\n\n".join(output_paragraphs)  # Join with '\n\n' for markdown format.

     @staticmethod
     def update_citation_index(s, citation_map):
         """Update citation index in the string based on the citation map."""
         for original_citation in citation_map:
-            s = s.replace(f"[{original_citation}]", f"__PLACEHOLDER_{original_citation}__")
+            s = s.replace(
+                f"[{original_citation}]", f"__PLACEHOLDER_{original_citation}__"
+            )
         for original_citation, unify_citation in citation_map.items():
             s = s.replace(f"__PLACEHOLDER_{original_citation}__", f"[{unify_citation}]")

@@ -275,34 +302,34 @@ def parse_article_into_dict(input_string):
         A dictionary that contains the section title as the key, and another dictionary
         as the value, which includes the 'content' and 'subsections' keys as described above.
     """
-    lines = input_string.split('\n')
+        lines = input_string.split("\n")
         lines = [line for line in lines if line.strip()]
-        root = {'content': '', 'subsections': {}}
+        root = {"content": "", "subsections": {}}
         current_path = [(root, -1)]  # (current_dict, level)

         for line in lines:
-            if line.startswith('#'):
-                level = line.count('#')
-                title = line.strip('# ').strip()
-                new_section = {'content': '', 'subsections': {}}
+            if line.startswith("#"):
+                level = line.count("#")
+                title = line.strip("# ").strip()
+                new_section = {"content": "", "subsections": {}}

                 # Pop from the stack until we find the parent level
                 while current_path and current_path[-1][1] >= level:
                     current_path.pop()

                 # Append new section to the nearest upper level's subsections
-                current_path[-1][0]['subsections'][title] = new_section
+                current_path[-1][0]["subsections"][title] = new_section
                 current_path.append((new_section, level))
             else:
-                current_path[-1][0]['content'] += line + '\n'
+                current_path[-1][0]["content"] += line + "\n"

-        return root['subsections']
+        return root["subsections"]

 class FileIOHelper:
     @staticmethod
     def dump_json(obj, file_name, encoding="utf-8"):
-        with open(file_name, 'w', encoding=encoding) as fw:
+        with open(file_name, "w", encoding=encoding) as fw:
             json.dump(obj, fw, default=FileIOHelper.handle_non_serializable)

     @staticmethod
     def handle_non_serializable(obj):
@@ -311,27 +338,27 @@
     def load_json(file_name, encoding="utf-8"):
-        with open(file_name, 'r', encoding=encoding) as fr:
+        with open(file_name, "r", encoding=encoding) as fr:
             return json.load(fr)

     @staticmethod
     def write_str(s, path):
-        with open(path, 'w') as f:
+        with open(path, "w") as f:
             f.write(s)

     @staticmethod
     def load_str(path):
-        with open(path, 'r') as f:
-            return '\n'.join(f.readlines())
+        with open(path, "r") as f:
+            return "\n".join(f.readlines())

     @staticmethod
     def dump_pickle(obj, path):
- with open(path, 'wb') as f: + with open(path, "wb") as f: pickle.dump(obj, f) @staticmethod def load_pickle(path): - with open(path, 'rb') as f: + with open(path, "rb") as f: return pickle.load(f) @@ -341,7 +368,12 @@ class WebPageHelper: Acknowledgement: Part of the code is adapted from https://github.com/stanford-oval/WikiChat project. """ - def __init__(self, min_char_count: int = 150, snippet_chunk_size: int = 1000, max_thread_num: int = 10): + def __init__( + self, + min_char_count: int = 150, + snippet_chunk_size: int = 1000, + max_thread_num: int = 10, + ): """ Args: min_char_count: Minimum character count for the article to be considered valid. @@ -382,7 +414,9 @@ def download_webpage(self, url: str): return None def urls_to_articles(self, urls: List[str]) -> Dict: - with concurrent.futures.ThreadPoolExecutor(max_workers=self.max_thread_num) as executor: + with concurrent.futures.ThreadPoolExecutor( + max_workers=self.max_thread_num + ) as executor: htmls = list(executor.map(self.download_webpage, urls)) articles = {} diff --git a/requirements.txt b/requirements.txt index 8ac1b95b..79ee9d68 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,3 +7,4 @@ trafilatura langchain-huggingface qdrant-client langchain-qdrant +numpy==1.26.4 From 85a1e6e4391059635109304c1f51fb0c486fe6bc Mon Sep 17 00:00:00 2001 From: zenith110 Date: Sun, 28 Jul 2024 10:32:14 -0400 Subject: [PATCH 3/7] Reverted back everything except requirements.txt, the new example file, and rm.py --- knowledge_storm/__init__.py | 2 +- knowledge_storm/interface.py | 69 ++-- knowledge_storm/lm.py | 228 +++++-------- knowledge_storm/storm_wiki/engine.py | 310 +++++++----------- .../storm_wiki/modules/article_generation.py | 117 +++---- .../storm_wiki/modules/article_polish.py | 46 +-- .../storm_wiki/modules/knowledge_curation.py | 249 +++++--------- .../storm_wiki/modules/outline_generation.py | 106 +++--- .../storm_wiki/modules/persona_generator.py | 64 ++-- .../storm_wiki/modules/retriever.py | 23 +- .../storm_wiki/modules/storm_dataclass.py | 233 +++++-------- knowledge_storm/utils.py | 166 ++++------ 12 files changed, 589 insertions(+), 1024 deletions(-) diff --git a/knowledge_storm/__init__.py b/knowledge_storm/__init__.py index 74dcabbe..f1fd18ea 100644 --- a/knowledge_storm/__init__.py +++ b/knowledge_storm/__init__.py @@ -1,5 +1,5 @@ from .storm_wiki.engine import ( STORMWikiLMConfigs, STORMWikiRunnerArguments, - STORMWikiRunner, + STORMWikiRunner ) diff --git a/knowledge_storm/interface.py b/knowledge_storm/interface.py index f6c11bd9..03df2fb6 100644 --- a/knowledge_storm/interface.py +++ b/knowledge_storm/interface.py @@ -5,9 +5,7 @@ from collections import OrderedDict from typing import Dict, List, Optional, Union -logging.basicConfig( - level=logging.INFO, format="%(name)s : %(levelname)-8s : %(message)s" -) +logging.basicConfig(level=logging.INFO, format='%(name)s : %(levelname)-8s : %(message)s') logger = logging.getLogger(__name__) @@ -72,9 +70,7 @@ class Article(ABC): def __init__(self, topic_name): self.root = ArticleSectionNode(topic_name) - def find_section( - self, node: ArticleSectionNode, name: str - ) -> Optional[ArticleSectionNode]: + def find_section(self, node: ArticleSectionNode, name: str) -> Optional[ArticleSectionNode]: """ Return the node of the section given the section name. 
@@ -156,9 +152,7 @@ def prune_empty_nodes(self, node=None): if node is None: node = self.root - node.children[:] = [ - child for child in node.children if self.prune_empty_nodes(child) - ] + node.children[:] = [child for child in node.children if self.prune_empty_nodes(child)] if (node.content is None or node.content == "") and not node.children: return None @@ -184,9 +178,7 @@ def update_search_top_k(self, k): def collect_and_reset_rm_usage(self): combined_usage = [] for attr_name in self.__dict__: - if "_rm" in attr_name and hasattr( - getattr(self, attr_name), "get_usage_and_reset" - ): + if '_rm' in attr_name and hasattr(getattr(self, attr_name), 'get_usage_and_reset'): combined_usage.append(getattr(self, attr_name).get_usage_and_reset()) name_to_usage = {} @@ -248,9 +240,7 @@ class OutlineGenerationModule(ABC): """ @abstractmethod - def generate_outline( - self, topic: str, information_table: InformationTable, **kwargs - ) -> Article: + def generate_outline(self, topic: str, information_table: InformationTable, **kwargs) -> Article: """ Generate outline for the article. Required arguments include: topic: the topic of interest @@ -273,13 +263,11 @@ class ArticleGenerationModule(ABC): """ @abstractmethod - def generate_article( - self, - topic: str, - information_table: InformationTable, - article_with_outline: Article, - **kwargs, - ) -> Article: + def generate_article(self, + topic: str, + information_table: InformationTable, + article_with_outline: Article, + **kwargs) -> Article: """ Generate article. Required arguments include: topic: the topic of interest @@ -324,15 +312,14 @@ def wrapper(self, *args, **kwargs): class LMConfigs(ABC): """Abstract base class for language model configurations of the knowledge curation engine. - The language model used for each part should be declared with a suffix '_lm' in the attribute name. - """ + The language model used for each part should be declared with a suffix '_lm' in the attribute name.""" def __init__(self): pass def init_check(self): for attr_name in self.__dict__: - if "_lm" in attr_name and getattr(self, attr_name) is None: + if '_lm' in attr_name and getattr(self, attr_name) is None: logging.warning( f"Language model for {attr_name} is not initialized. 
Please call set_{attr_name}()" ) @@ -340,7 +327,7 @@ def init_check(self): def collect_and_reset_lm_history(self): history = [] for attr_name in self.__dict__: - if "_lm" in attr_name and hasattr(getattr(self, attr_name), "history"): + if '_lm' in attr_name and hasattr(getattr(self, attr_name), 'history'): history.extend(getattr(self, attr_name).history) getattr(self, attr_name).history = [] @@ -349,9 +336,7 @@ def collect_and_reset_lm_history(self): def collect_and_reset_lm_usage(self): combined_usage = [] for attr_name in self.__dict__: - if "_lm" in attr_name and hasattr( - getattr(self, attr_name), "get_usage_and_reset" - ): + if '_lm' in attr_name and hasattr(getattr(self, attr_name), 'get_usage_and_reset'): combined_usage.append(getattr(self, attr_name).get_usage_and_reset()) model_name_to_usage = {} @@ -360,12 +345,8 @@ def collect_and_reset_lm_usage(self): if model_name not in model_name_to_usage: model_name_to_usage[model_name] = tokens else: - model_name_to_usage[model_name]["prompt_tokens"] += tokens[ - "prompt_tokens" - ] - model_name_to_usage[model_name]["completion_tokens"] += tokens[ - "completion_tokens" - ] + model_name_to_usage[model_name]['prompt_tokens'] += tokens['prompt_tokens'] + model_name_to_usage[model_name]['completion_tokens'] += tokens['completion_tokens'] return model_name_to_usage @@ -373,9 +354,8 @@ def log(self): return OrderedDict( { - attr_name: getattr(self, attr_name).kwargs - for attr_name in self.__dict__ - if "_lm" in attr_name and hasattr(getattr(self, attr_name), "kwargs") + attr_name: getattr(self, attr_name).kwargs for attr_name in self.__dict__ if + '_lm' in attr_name and hasattr(getattr(self, attr_name), 'kwargs') } ) @@ -399,21 +379,16 @@ def wrapper(*args, **kwargs): self.time[func.__name__] = execution_time logger.info(f"{func.__name__} executed in {execution_time:.4f} seconds") self.lm_cost[func.__name__] = self.lm_configs.collect_and_reset_lm_usage() - if hasattr(self, "retriever"): - self.rm_cost[func.__name__] = ( - self.retriever.collect_and_reset_rm_usage() - ) + if hasattr(self, 'retriever'): + self.rm_cost[func.__name__] = self.retriever.collect_and_reset_rm_usage() return result return wrapper def apply_decorators(self): """Apply decorators to methods that need them.""" - methods_to_decorate = [ - method_name - for method_name in dir(self) - if callable(getattr(self, method_name)) and method_name.startswith("run_") - ] + methods_to_decorate = [method_name for method_name in dir(self) + if callable(getattr(self, method_name)) and method_name.startswith('run_')] for method_name in methods_to_decorate: original_method = getattr(self, method_name) decorated_method = self.log_execution_time_and_lm_rm_usage(original_method) diff --git a/knowledge_storm/lm.py b/knowledge_storm/lm.py index 1aa34d24..e9c50852 100644 --- a/knowledge_storm/lm.py +++ b/knowledge_storm/lm.py @@ -9,10 +9,7 @@ import requests from dsp import ERRORS, backoff_hdlr, giveup_hdlr from dsp.modules.hf import openai_to_hf -from dsp.modules.hf_client import ( - send_hfvllm_request_v00, - send_hftgi_request_v01_wrapped, -) +from dsp.modules.hf_client import send_hfvllm_request_v00, send_hftgi_request_v01_wrapped from transformers import AutoTokenizer try: @@ -25,11 +22,11 @@ class OpenAIModel(dspy.OpenAI): """A wrapper class for dspy.OpenAI.""" def __init__( - self, - model: str = "gpt-3.5-turbo-instruct", - api_key: Optional[str] = None, - model_type: Literal["chat", "text"] = None, - **kwargs, + self, + model: str = "gpt-3.5-turbo-instruct", + api_key: Optional[str] = 
None, + model_type: Literal["chat", "text"] = None, + **kwargs ): super().__init__(model=model, api_key=api_key, model_type=model_type, **kwargs) self._token_usage_lock = threading.Lock() @@ -38,20 +35,17 @@ def __init__( def log_usage(self, response): """Log the total tokens from the OpenAI API response.""" - usage_data = response.get("usage") + usage_data = response.get('usage') if usage_data: with self._token_usage_lock: - self.prompt_tokens += usage_data.get("prompt_tokens", 0) - self.completion_tokens += usage_data.get("completion_tokens", 0) + self.prompt_tokens += usage_data.get('prompt_tokens', 0) + self.completion_tokens += usage_data.get('completion_tokens', 0) def get_usage_and_reset(self): """Get the total tokens used and reset the token usage.""" usage = { - self.kwargs.get("model") - or self.kwargs.get("engine"): { - "prompt_tokens": self.prompt_tokens, - "completion_tokens": self.completion_tokens, - } + self.kwargs.get('model') or self.kwargs.get('engine'): + {'prompt_tokens': self.prompt_tokens, 'completion_tokens': self.completion_tokens} } self.prompt_tokens = 0 self.completion_tokens = 0 @@ -59,11 +53,11 @@ def get_usage_and_reset(self): return usage def __call__( - self, - prompt: str, - only_completed: bool = True, - return_sorted: bool = False, - **kwargs, + self, + prompt: str, + only_completed: bool = True, + return_sorted: bool = False, + **kwargs, ) -> list[dict[str, Any]]: """Copied from dspy/dsp/modules/gpt3.py with the addition of tracking token usage.""" @@ -115,11 +109,11 @@ class DeepSeekModel(dspy.OpenAI): """A wrapper class for DeepSeek API, compatible with dspy.OpenAI.""" def __init__( - self, - model: str = "deepseek-chat", - api_key: Optional[str] = None, - api_base: str = "https://api.deepseek.com", - **kwargs, + self, + model: str = "deepseek-chat", + api_key: Optional[str] = None, + api_base: str = "https://api.deepseek.com", + **kwargs ): super().__init__(model=model, api_key=api_key, api_base=api_base, **kwargs) self._token_usage_lock = threading.Lock() @@ -129,25 +123,21 @@ def __init__( self.api_key = api_key or os.getenv("DEEPSEEK_API_KEY") self.api_base = api_base if not self.api_key: - raise ValueError( - "DeepSeek API key must be provided either as an argument or as an environment variable DEEPSEEK_API_KEY" - ) + raise ValueError("DeepSeek API key must be provided either as an argument or as an environment variable DEEPSEEK_API_KEY") def log_usage(self, response): """Log the total tokens from the DeepSeek API response.""" - usage_data = response.get("usage") + usage_data = response.get('usage') if usage_data: with self._token_usage_lock: - self.prompt_tokens += usage_data.get("prompt_tokens", 0) - self.completion_tokens += usage_data.get("completion_tokens", 0) + self.prompt_tokens += usage_data.get('prompt_tokens', 0) + self.completion_tokens += usage_data.get('completion_tokens', 0) def get_usage_and_reset(self): """Get the total tokens used and reset the token usage.""" usage = { - self.model: { - "prompt_tokens": self.prompt_tokens, - "completion_tokens": self.completion_tokens, - } + self.model: + {'prompt_tokens': self.prompt_tokens, 'completion_tokens': self.completion_tokens} } self.prompt_tokens = 0 self.completion_tokens = 0 @@ -164,25 +154,23 @@ def _create_completion(self, prompt: str, **kwargs): """Create a completion using the DeepSeek API.""" headers = { "Content-Type": "application/json", - "Authorization": f"Bearer {self.api_key}", + "Authorization": f"Bearer {self.api_key}" } data = { "model": self.model, "messages": 
[{"role": "user", "content": prompt}], - **kwargs, + **kwargs } - response = requests.post( - f"{self.api_base}/v1/chat/completions", headers=headers, json=data - ) + response = requests.post(f"{self.api_base}/v1/chat/completions", headers=headers, json=data) response.raise_for_status() return response.json() def __call__( - self, - prompt: str, - only_completed: bool = True, - return_sorted: bool = False, - **kwargs, + self, + prompt: str, + only_completed: bool = True, + return_sorted: bool = False, + **kwargs, ) -> list[dict[str, Any]]: """Call the DeepSeek API to generate completions.""" assert only_completed, "for now" @@ -208,46 +196,35 @@ def __call__( class AzureOpenAIModel(dspy.AzureOpenAI): """A wrapper class for dspy.AzureOpenAI.""" - def __init__( - self, - api_base: Optional[str] = None, - api_version: Optional[str] = None, - model: str = "gpt-3.5-turbo-instruct", - api_key: Optional[str] = None, - model_type: Literal["chat", "text"] = "chat", - **kwargs, + self, + api_base: Optional[str] = None, + api_version: Optional[str] = None, + model: str = "gpt-3.5-turbo-instruct", + api_key: Optional[str] = None, + model_type: Literal["chat", "text"] = "chat", + **kwargs, ): super().__init__( - api_base=api_base, - api_version=api_version, - model=model, - api_key=api_key, - model_type=model_type, - **kwargs, - ) + api_base=api_base, api_version=api_version, model=model, api_key=api_key, model_type=model_type, **kwargs) self._token_usage_lock = threading.Lock() self.prompt_tokens = 0 self.completion_tokens = 0 def log_usage(self, response): """Log the total tokens from the OpenAI API response. - Override log_usage() in dspy.AzureOpenAI for tracking accumulated token usage. - """ - usage_data = response.get("usage") + Override log_usage() in dspy.AzureOpenAI for tracking accumulated token usage.""" + usage_data = response.get('usage') if usage_data: with self._token_usage_lock: - self.prompt_tokens += usage_data.get("prompt_tokens", 0) - self.completion_tokens += usage_data.get("completion_tokens", 0) + self.prompt_tokens += usage_data.get('prompt_tokens', 0) + self.completion_tokens += usage_data.get('completion_tokens', 0) def get_usage_and_reset(self): """Get the total tokens used and reset the token usage.""" usage = { - self.kwargs.get("model") - or self.kwargs.get("engine"): { - "prompt_tokens": self.prompt_tokens, - "completion_tokens": self.completion_tokens, - } + self.kwargs.get('model') or self.kwargs.get('engine'): + {'prompt_tokens': self.prompt_tokens, 'completion_tokens': self.completion_tokens} } self.prompt_tokens = 0 self.completion_tokens = 0 @@ -259,11 +236,11 @@ class ClaudeModel(dspy.dsp.modules.lm.LM): """Copied from dspy/dsp/modules/anthropic.py with the addition of tracking token usage.""" def __init__( - self, - model: str, - api_key: Optional[str] = None, - api_base: Optional[str] = None, - **kwargs, + self, + model: str, + api_key: Optional[str] = None, + api_base: Optional[str] = None, + **kwargs, ): super().__init__(model) try: @@ -272,21 +249,12 @@ def __init__( raise ImportError("Claude requires `pip install anthropic`.") from err self.provider = "anthropic" - self.api_key = api_key = ( - os.environ.get("ANTHROPIC_API_KEY") if api_key is None else api_key - ) - self.api_base = ( - "https://api.anthropic.com/v1/messages" if api_base is None else api_base - ) - self.kwargs = { - "temperature": kwargs.get("temperature", 0.0), - "max_tokens": min(kwargs.get("max_tokens", 4096), 4096), - "top_p": kwargs.get("top_p", 1.0), - "top_k": kwargs.get("top_k", 1), - 
"n": kwargs.pop("n", kwargs.pop("num_generations", 1)), - **kwargs, - "model": model, - } + self.api_key = api_key = os.environ.get("ANTHROPIC_API_KEY") if api_key is None else api_key + self.api_base = "https://api.anthropic.com/v1/messages" if api_base is None else api_base + self.kwargs = {"temperature": kwargs.get("temperature", 0.0), + "max_tokens": min(kwargs.get("max_tokens", 4096), 4096), "top_p": kwargs.get("top_p", 1.0), + "top_k": kwargs.get("top_k", 1), "n": kwargs.pop("n", kwargs.pop("num_generations", 1)), + **kwargs, "model": model} self.history: list[dict[str, Any]] = [] self.client = Anthropic(api_key=api_key) self.model = model @@ -306,10 +274,8 @@ def log_usage(self, response): def get_usage_and_reset(self): """Get the total tokens used and reset the token usage.""" usage = { - self.model: { - "prompt_tokens": self.prompt_tokens, - "completion_tokens": self.completion_tokens, - } + self.model: + {'prompt_tokens': self.prompt_tokens, 'completion_tokens': self.completion_tokens} } self.prompt_tokens = 0 self.completion_tokens = 0 @@ -341,7 +307,7 @@ def basic_request(self, prompt: str, **kwargs): "usage": { "input_tokens": response.usage.input_tokens, "output_tokens": response.usage.output_tokens, - }, + } }, "kwargs": kwargs, "raw_kwargs": raw_kwargs, @@ -411,7 +377,10 @@ def _generate(self, prompt, **kwargs): # "max_tokens": kwargs["max_tokens"], # "temperature": kwargs["temperature"], # } - payload = {"prompt": prompt, **kwargs} + payload = { + "prompt": prompt, + **kwargs + } response = send_hfvllm_request_v00( f"{self.url}/v1/completions", @@ -444,17 +413,11 @@ def __init__(self, model, port, url="http://localhost", **kwargs): super().__init__(model=model, base_url=f"{url}:{port}", **kwargs) # Store additional kwargs for the generate method. 
self.kwargs = {**self.kwargs, **kwargs} - + class TGIClient(dspy.HFClientTGI): def __init__(self, model, port, url, http_request_kwargs=None, **kwargs): - super().__init__( - model=model, - port=port, - url=url, - http_request_kwargs=http_request_kwargs, - **kwargs, - ) + super().__init__(model=model, port=port, url=url, http_request_kwargs=http_request_kwargs, **kwargs) def _generate(self, prompt, **kwargs): """Copied from dspy/dsp/modules/hf_client.py with the addition of removing hard-coded parameters.""" @@ -493,8 +456,8 @@ def _generate(self, prompt, **kwargs): completions = [json_response["generated_text"]] if ( - "details" in json_response - and "best_of_sequences" in json_response["details"] + "details" in json_response + and "best_of_sequences" in json_response["details"] ): completions += [ x["generated_text"] @@ -511,22 +474,13 @@ def _generate(self, prompt, **kwargs): class TogetherClient(dspy.HFModel): """A wrapper class for dspy.Together.""" - def __init__( - self, - model, - apply_tokenizer_chat_template=False, - hf_tokenizer_name=None, - **kwargs, - ): + def __init__(self, model, apply_tokenizer_chat_template=False, hf_tokenizer_name=None, **kwargs): """Copied from dspy/dsp/modules/hf_client.py with the support of applying tokenizer chat template.""" super().__init__(model=model, is_client=True) self.session = requests.Session() - self.api_base = ( - "https://api.together.xyz/v1/completions" - if os.getenv("TOGETHER_API_BASE") is None - else os.getenv("TOGETHER_API_BASE") - ) + self.api_base = "https://api.together.xyz/v1/completions" if os.getenv( + "TOGETHER_API_BASE") is None else os.getenv("TOGETHER_API_BASE") self.token = os.getenv("TOGETHER_API_KEY") self.model = model @@ -538,9 +492,7 @@ def __init__( logging.info("Loading huggingface tokenizer.") if hf_tokenizer_name is None: hf_tokenizer_name = self.model - self.tokenizer = AutoTokenizer.from_pretrained( - hf_tokenizer_name, cache_dir=kwargs.get("cache_dir", None) - ) + self.tokenizer = AutoTokenizer.from_pretrained(hf_tokenizer_name, cache_dir=kwargs.get("cache_dir", None)) stop_default = "\n\n---" @@ -560,19 +512,17 @@ def __init__( def log_usage(self, response): """Log the total tokens from the OpenAI API response.""" - usage_data = response.get("usage") + usage_data = response.get('usage') if usage_data: with self._token_usage_lock: - self.prompt_tokens += usage_data.get("prompt_tokens", 0) - self.completion_tokens += usage_data.get("completion_tokens", 0) + self.prompt_tokens += usage_data.get('prompt_tokens', 0) + self.completion_tokens += usage_data.get('completion_tokens', 0) def get_usage_and_reset(self): """Get the total tokens used and reset the token usage.""" usage = { - self.model: { - "prompt_tokens": self.prompt_tokens, - "completion_tokens": self.completion_tokens, - } + self.model: + {'prompt_tokens': self.prompt_tokens, 'completion_tokens': self.completion_tokens} } self.prompt_tokens = 0 self.completion_tokens = 0 @@ -597,18 +547,14 @@ def _generate(self, prompt, use_chat_api=False, **kwargs): top_k = kwargs.get("top_k", 50) repetition_penalty = kwargs.get("repetition_penalty", 1) if self.apply_tokenizer_chat_template: - prompt = self.tokenizer.apply_chat_template( - [{"role": "user", "content": prompt}], tokenize=False - ) + prompt = self.tokenizer.apply_chat_template([{"role": "user", "content": prompt}], tokenize=False) # prompt = f"[INST]{prompt}[/INST]" if self.use_inst_template else prompt if use_chat_api: url = f"{self.api_base}/chat/completions" messages = [ - { - "role": "system", - 
"content": "You are a helpful assistant. You must continue the user text directly without *any* additional interjections.", - }, + {"role": "system", + "content": "You are a helpful assistant. You must continue the user text directly without *any* additional interjections."}, {"role": "user", "content": prompt}, ] body = { @@ -641,13 +587,9 @@ def _generate(self, prompt, use_chat_api=False, **kwargs): self.log_usage(resp_json) if use_chat_api: # completions = [resp_json['output'].get('choices', [])[0].get('message', {}).get('content', "")] - completions = [ - resp_json.get("choices", [])[0] - .get("message", {}) - .get("content", "") - ] + completions = [resp_json.get('choices', [])[0].get('message', {}).get('content', "")] else: # completions = [resp_json['output'].get('choices', [])[0].get('text', "")] - completions = [resp_json.get("choices", [])[0].get("text", "")] + completions = [resp_json.get('choices', [])[0].get('text', "")] response = {"prompt": prompt, "choices": [{"text": c} for c in completions]} return response diff --git a/knowledge_storm/storm_wiki/engine.py b/knowledge_storm/storm_wiki/engine.py index 746a07b0..e0c8dfcc 100644 --- a/knowledge_storm/storm_wiki/engine.py +++ b/knowledge_storm/storm_wiki/engine.py @@ -28,52 +28,43 @@ class STORMWikiLMConfigs(LMConfigs): """ def __init__(self): - self.conv_simulator_lm = ( - None # LLM used in conversation simulator except for question asking. - ) + self.conv_simulator_lm = None # LLM used in conversation simulator except for question asking. self.question_asker_lm = None # LLM used in question asking. self.outline_gen_lm = None # LLM used in outline generation. self.article_gen_lm = None # LLM used in article generation. self.article_polish_lm = None # LLM used in article polishing. def init_openai_model( - self, - openai_api_key: str, - openai_type: Literal["openai", "azure"], - api_base: Optional[str] = None, - api_version: Optional[str] = None, - temperature: Optional[float] = 1.0, - top_p: Optional[float] = 0.9, + self, + openai_api_key: str, + openai_type: Literal["openai", "azure"], + api_base: Optional[str] = None, + api_version: Optional[str] = None, + temperature: Optional[float] = 1.0, + top_p: Optional[float] = 0.9 ): """Legacy: Corresponding to the original setup in the NAACL'24 paper.""" openai_kwargs = { - "api_key": openai_api_key, - "api_provider": openai_type, - "temperature": temperature, - "top_p": top_p, - "api_base": None, + 'api_key': openai_api_key, + 'api_provider': openai_type, + 'temperature': temperature, + 'top_p': top_p, + 'api_base': None } - if openai_type and openai_type == "openai": - self.conv_simulator_lm = OpenAIModel( - model="gpt-3.5-turbo-instruct", max_tokens=500, **openai_kwargs - ) - self.question_asker_lm = OpenAIModel( - model="gpt-3.5-turbo", max_tokens=500, **openai_kwargs - ) + if openai_type and openai_type == 'openai': + self.conv_simulator_lm = OpenAIModel(model='gpt-3.5-turbo-instruct', + max_tokens=500, **openai_kwargs) + self.question_asker_lm = OpenAIModel(model='gpt-3.5-turbo', + max_tokens=500, **openai_kwargs) # 1/12/2024: Update gpt-4 to gpt-4-1106-preview. (Currently keep the original setup when using azure.) 
- self.outline_gen_lm = OpenAIModel( - model="gpt-4-0125-preview", max_tokens=400, **openai_kwargs - ) - self.article_gen_lm = OpenAIModel( - model="gpt-4o-2024-05-13", max_tokens=700, **openai_kwargs - ) - self.article_polish_lm = OpenAIModel( - model="gpt-4o-2024-05-13", max_tokens=4000, **openai_kwargs - ) + self.outline_gen_lm = OpenAIModel(model='gpt-4-0125-preview', + max_tokens=400, **openai_kwargs) + self.article_gen_lm = OpenAIModel(model='gpt-4o-2024-05-13', + max_tokens=700, **openai_kwargs) + self.article_polish_lm = OpenAIModel(model='gpt-4o-2024-05-13', + max_tokens=4000, **openai_kwargs) else: - logging.warning( - "No valid OpenAI API provider is provided. Cannot use default LLM configurations." - ) + logging.warning('No valid OpenAI API provider is provided. Cannot use default LLM configurations.') def set_conv_simulator_lm(self, model: Union[dspy.dsp.LM, dspy.dsp.HFModel]): self.conv_simulator_lm = model @@ -94,21 +85,16 @@ def set_article_polish_lm(self, model: Union[dspy.dsp.LM, dspy.dsp.HFModel]): @dataclass class STORMWikiRunnerArguments: """Arguments for controlling the STORM Wiki pipeline.""" - output_dir: str = field( metadata={"help": "Output directory for the results."}, ) max_conv_turn: int = field( default=3, - metadata={ - "help": "Maximum number of questions in conversational question asking." - }, + metadata={"help": "Maximum number of questions in conversational question asking."}, ) max_perspective: int = field( default=3, - metadata={ - "help": "Maximum number of perspectives to consider in perspective-guided question asking." - }, + metadata={"help": "Maximum number of perspectives to consider in perspective-guided question asking."}, ) max_search_queries_per_turn: int = field( default=3, @@ -128,27 +114,24 @@ class STORMWikiRunnerArguments: ) max_thread_num: int = field( default=10, - metadata={ - "help": "Maximum number of threads to use. " - "Consider reducing it if keep getting 'Exceed rate limit' error when calling LM API." - }, + metadata={"help": "Maximum number of threads to use. 
" + "Consider reducing it if keep getting 'Exceed rate limit' error when calling LM API."}, ) class STORMWikiRunner(Engine): """STORM Wiki pipeline runner.""" - def __init__( - self, args: STORMWikiRunnerArguments, lm_configs: STORMWikiLMConfigs, rm - ): + def __init__(self, + args: STORMWikiRunnerArguments, + lm_configs: STORMWikiLMConfigs, + rm): super().__init__(lm_configs=lm_configs) self.args = args self.lm_configs = lm_configs self.retriever = StormRetriever(rm=rm, k=self.args.retrieve_top_k) - storm_persona_generator = StormPersonaGenerator( - self.lm_configs.question_asker_lm - ) + storm_persona_generator = StormPersonaGenerator(self.lm_configs.question_asker_lm) self.storm_knowledge_curation_module = StormKnowledgeCurationModule( retriever=self.retriever, persona_generator=storm_persona_generator, @@ -157,7 +140,7 @@ def __init__( max_search_queries_per_turn=self.args.max_search_queries_per_turn, search_top_k=self.args.search_top_k, max_conv_turn=self.args.max_conv_turn, - max_thread_num=self.args.max_thread_num, + max_thread_num=self.args.max_thread_num ) self.storm_outline_generation_module = StormOutlineGenerationModule( outline_gen_lm=self.lm_configs.outline_gen_lm @@ -165,96 +148,73 @@ def __init__( self.storm_article_generation = StormArticleGenerationModule( article_gen_lm=self.lm_configs.article_gen_lm, retrieve_top_k=self.args.retrieve_top_k, - max_thread_num=self.args.max_thread_num, + max_thread_num=self.args.max_thread_num ) self.storm_article_polishing_module = StormArticlePolishingModule( article_gen_lm=self.lm_configs.article_gen_lm, - article_polish_lm=self.lm_configs.article_polish_lm, + article_polish_lm=self.lm_configs.article_polish_lm ) self.lm_configs.init_check() self.apply_decorators() - def run_knowledge_curation_module( - self, - ground_truth_url: str = "None", - callback_handler: BaseCallbackHandler = None, - ) -> StormInformationTable: - - information_table, conversation_log = ( - self.storm_knowledge_curation_module.research( - topic=self.topic, - ground_truth_url=ground_truth_url, - callback_handler=callback_handler, - max_perspective=self.args.max_perspective, - disable_perspective=False, - return_conversation_log=True, - ) - ) + def run_knowledge_curation_module(self, + ground_truth_url: str = "None", + callback_handler: BaseCallbackHandler = None) -> StormInformationTable: - FileIOHelper.dump_json( - conversation_log, - os.path.join(self.article_output_dir, "conversation_log.json"), - ) - information_table.dump_url_to_info( - os.path.join(self.article_output_dir, "raw_search_results.json") + information_table, conversation_log = self.storm_knowledge_curation_module.research( + topic=self.topic, + ground_truth_url=ground_truth_url, + callback_handler=callback_handler, + max_perspective=self.args.max_perspective, + disable_perspective=False, + return_conversation_log=True ) + + FileIOHelper.dump_json(conversation_log, os.path.join(self.article_output_dir, 'conversation_log.json')) + information_table.dump_url_to_info(os.path.join(self.article_output_dir, 'raw_search_results.json')) return information_table - def run_outline_generation_module( - self, - information_table: StormInformationTable, - callback_handler: BaseCallbackHandler = None, - ) -> StormArticle: + def run_outline_generation_module(self, + information_table: StormInformationTable, + callback_handler: BaseCallbackHandler = None) -> StormArticle: outline, draft_outline = self.storm_outline_generation_module.generate_outline( topic=self.topic, information_table=information_table, 
return_draft_outline=True,
-            callback_handler=callback_handler,
-        )
-        outline.dump_outline_to_file(
-            os.path.join(self.article_output_dir, "storm_gen_outline.txt")
-        )
-        draft_outline.dump_outline_to_file(
-            os.path.join(self.article_output_dir, "direct_gen_outline.txt")
+            callback_handler=callback_handler
         )
+        outline.dump_outline_to_file(os.path.join(self.article_output_dir, 'storm_gen_outline.txt'))
+        draft_outline.dump_outline_to_file(os.path.join(self.article_output_dir, "direct_gen_outline.txt"))
         return outline

-    def run_article_generation_module(
-        self,
-        outline: StormArticle,
-        information_table: StormInformationTable,
-        callback_handler: BaseCallbackHandler = None,
-    ) -> StormArticle:
+    def run_article_generation_module(self,
+                                      outline: StormArticle,
+                                      information_table: StormInformationTable,
+                                      callback_handler: BaseCallbackHandler = None) -> StormArticle:

         draft_article = self.storm_article_generation.generate_article(
             topic=topic,
             information_table=information_table,
             article_with_outline=outline,
-            callback_handler=callback_handler,
-        )
-        draft_article.dump_article_as_plain_text(
-            os.path.join(self.article_output_dir, "storm_gen_article.txt")
-        )
-        draft_article.dump_reference_to_file(
-            os.path.join(self.article_output_dir, "url_to_info.json")
+            callback_handler=callback_handler
         )
+        draft_article.dump_article_as_plain_text(os.path.join(self.article_output_dir, 'storm_gen_article.txt'))
+        draft_article.dump_reference_to_file(os.path.join(self.article_output_dir, 'url_to_info.json'))
         return draft_article

-    def run_article_polishing_module(
-        self, draft_article: StormArticle, remove_duplicate: bool = False
-    ) -> StormArticle:
+    def run_article_polishing_module(self,
+                                     draft_article: StormArticle,
+                                     remove_duplicate: bool = False) -> StormArticle:

         polished_article = self.storm_article_polishing_module.polish_article(
             topic=self.topic,
             draft_article=draft_article,
-            remove_duplicate=remove_duplicate,
-        )
-        FileIOHelper.write_str(
-            polished_article.to_string(),
-            os.path.join(self.article_output_dir, "storm_gen_article_polished.txt"),
+            remove_duplicate=remove_duplicate
         )
+        FileIOHelper.write_str(polished_article.to_string(),
+                               os.path.join(self.article_output_dir, 'storm_gen_article_polished.txt'))
         return polished_article

     def post_run(self):
@@ -264,61 +224,43 @@ def post_run(self):
         2. Dumping the LLM call history.
         """
         config_log = self.lm_configs.log()
-        FileIOHelper.dump_json(
-            config_log, os.path.join(self.article_output_dir, "run_config.json")
-        )
+        FileIOHelper.dump_json(config_log, os.path.join(self.article_output_dir, 'run_config.json'))

         llm_call_history = self.lm_configs.collect_and_reset_lm_history()
-        with open(
-            os.path.join(self.article_output_dir, "llm_call_history.jsonl"), "w"
-        ) as f:
+        with open(os.path.join(self.article_output_dir, 'llm_call_history.jsonl'), 'w') as f:
             for call in llm_call_history:
-                if "kwargs" in call:
-                    call.pop(
-                        "kwargs"
-                    )  # All kwargs are dumped together to run_config.json.
-                f.write(json.dumps(call) + "\n")
+                if 'kwargs' in call:
+                    call.pop('kwargs')  # All kwargs are dumped together to run_config.json.
+                f.write(json.dumps(call) + '\n')

     def _load_information_table_from_local_fs(self, information_table_local_path):
         assert os.path.exists(information_table_local_path), makeStringRed(
-            f"{information_table_local_path} does not exist. Please set --do-research argument to prepare the conversation_log.json for this topic."
-        )
-        return StormInformationTable.from_conversation_log_file(
-            information_table_local_path
-        )
+            f"{information_table_local_path} does not exist. Please set --do-research argument to prepare the conversation_log.json for this topic.")
+        return StormInformationTable.from_conversation_log_file(information_table_local_path)

     def _load_outline_from_local_fs(self, topic, outline_local_path):
         assert os.path.exists(outline_local_path), makeStringRed(
-            f"{outline_local_path} does not exist. Please set --do-generate-outline argument to prepare the storm_gen_outline.txt for this topic."
-        )
+            f"{outline_local_path} does not exist. Please set --do-generate-outline argument to prepare the storm_gen_outline.txt for this topic.")
         return StormArticle.from_outline_file(topic=topic, file_path=outline_local_path)

-    def _load_draft_article_from_local_fs(
-        self, topic, draft_article_path, url_to_info_path
-    ):
+    def _load_draft_article_from_local_fs(self, topic, draft_article_path, url_to_info_path):
         assert os.path.exists(draft_article_path), makeStringRed(
-            f"{draft_article_path} does not exist. Please set --do-generate-article argument to prepare the storm_gen_article.txt for this topic."
-        )
+            f"{draft_article_path} does not exist. Please set --do-generate-article argument to prepare the storm_gen_article.txt for this topic.")
         assert os.path.exists(url_to_info_path), makeStringRed(
-            f"{url_to_info_path} does not exist. Please set --do-generate-article argument to prepare the url_to_info.json for this topic."
-        )
+            f"{url_to_info_path} does not exist. Please set --do-generate-article argument to prepare the url_to_info.json for this topic.")
         article_text = FileIOHelper.load_str(draft_article_path)
         references = FileIOHelper.load_json(url_to_info_path)
-        return StormArticle.from_string(
-            topic_name=topic, article_text=article_text, references=references
-        )
-
-    def run(
-        self,
-        topic: str,
-        ground_truth_url: str = "",
-        do_research: bool = True,
-        do_generate_outline: bool = True,
-        do_generate_article: bool = True,
-        do_polish_article: bool = True,
-        remove_duplicate: bool = False,
-        callback_handler: BaseCallbackHandler = BaseCallbackHandler(),
-    ):
+        return StormArticle.from_string(topic_name=topic, article_text=article_text, references=references)
+
+    def run(self,
+            topic: str,
+            ground_truth_url: str = '',
+            do_research: bool = True,
+            do_generate_outline: bool = True,
+            do_generate_article: bool = True,
+            do_polish_article: bool = True,
+            remove_duplicate: bool = False,
+            callback_handler: BaseCallbackHandler = BaseCallbackHandler()):
         """
         Run the STORM pipeline.
@@ -336,74 +278,50 @@ def run(
             remove_duplicate: If True, remove duplicated content.
             callback_handler: A callback handler to handle the intermediate results.
         """
-        assert (
-            do_research
-            or do_generate_outline
-            or do_generate_article
-            or do_polish_article
-        ), makeStringRed(
-            "No action is specified. Please set at least one of --do-research, --do-generate-outline, --do-generate-article, --do-polish-article"
-        )
+        assert do_research or do_generate_outline or do_generate_article or do_polish_article, \
+            makeStringRed(
                "No action is specified. 
Please set at least one of --do-research, --do-generate-outline, --do-generate-article, --do-polish-article") self.topic = topic - self.article_dir_name = topic.replace(" ", "_").replace("/", "_") - self.article_output_dir = os.path.join( - self.args.output_dir, self.article_dir_name - ) + self.article_dir_name = topic.replace(' ', '_').replace('/', '_') + self.article_output_dir = os.path.join(self.args.output_dir, self.article_dir_name) os.makedirs(self.article_output_dir, exist_ok=True) # research module information_table: StormInformationTable = None if do_research: - information_table = self.run_knowledge_curation_module( - ground_truth_url=ground_truth_url, callback_handler=callback_handler - ) + information_table = self.run_knowledge_curation_module(ground_truth_url=ground_truth_url, + callback_handler=callback_handler) # outline generation module outline: StormArticle = None if do_generate_outline: # load information table if it's not initialized if information_table is None: information_table = self._load_information_table_from_local_fs( - os.path.join(self.article_output_dir, "conversation_log.json") - ) - outline = self.run_outline_generation_module( - information_table=information_table, callback_handler=callback_handler - ) + os.path.join(self.article_output_dir, 'conversation_log.json')) + outline = self.run_outline_generation_module(information_table=information_table, + callback_handler=callback_handler) # article generation module draft_article: StormArticle = None if do_generate_article: if information_table is None: information_table = self._load_information_table_from_local_fs( - os.path.join(self.article_output_dir, "conversation_log.json") - ) + os.path.join(self.article_output_dir, 'conversation_log.json')) if outline is None: - outline = self._load_outline_from_local_fs( - topic=topic, - outline_local_path=os.path.join( - self.article_output_dir, "storm_gen_outline.txt" - ), - ) - draft_article = self.run_article_generation_module( - outline=outline, - information_table=information_table, - callback_handler=callback_handler, - ) + outline = self._load_outline_from_local_fs(topic=topic, + outline_local_path=os.path.join(self.article_output_dir, + 'storm_gen_outline.txt')) + draft_article = self.run_article_generation_module(outline=outline, + information_table=information_table, + callback_handler=callback_handler) # article polishing module if do_polish_article: if draft_article is None: - draft_article_path = os.path.join( - self.article_output_dir, "storm_gen_article.txt" - ) - url_to_info_path = os.path.join( - self.article_output_dir, "url_to_info.json" - ) - draft_article = self._load_draft_article_from_local_fs( - topic=topic, - draft_article_path=draft_article_path, - url_to_info_path=url_to_info_path, - ) - self.run_article_polishing_module( - draft_article=draft_article, remove_duplicate=remove_duplicate - ) + draft_article_path = os.path.join(self.article_output_dir, 'storm_gen_article.txt') + url_to_info_path = os.path.join(self.article_output_dir, 'url_to_info.json') + draft_article = self._load_draft_article_from_local_fs(topic=topic, + draft_article_path=draft_article_path, + url_to_info_path=url_to_info_path) + self.run_article_polishing_module(draft_article=draft_article, remove_duplicate=remove_duplicate) diff --git a/knowledge_storm/storm_wiki/modules/article_generation.py b/knowledge_storm/storm_wiki/modules/article_generation.py index 2e711465..a114b3ec 100644 --- a/knowledge_storm/storm_wiki/modules/article_generation.py +++ 
b/knowledge_storm/storm_wiki/modules/article_generation.py @@ -15,48 +15,35 @@ class StormArticleGenerationModule(ArticleGenerationModule): """ The interface for article generation stage. Given topic, collected information from - knowledge curation stage, generated outline from outline generation stage, + knowledge curation stage, generated outline from outline generation stage, """ - def __init__( - self, - article_gen_lm=Union[dspy.dsp.LM, dspy.dsp.HFModel], - retrieve_top_k: int = 5, - max_thread_num: int = 10, - ): + def __init__(self, + article_gen_lm=Union[dspy.dsp.LM, dspy.dsp.HFModel], + retrieve_top_k: int = 5, + max_thread_num: int = 10): super().__init__() self.retrieve_top_k = retrieve_top_k self.article_gen_lm = article_gen_lm self.max_thread_num = max_thread_num self.section_gen = ConvToSection(engine=self.article_gen_lm) - def generate_section( - self, topic, section_name, information_table, section_outline, section_query - ): + def generate_section(self, topic, section_name, information_table, section_outline, section_query): collected_info: List[StormInformation] = [] if information_table is not None: - collected_info = information_table.retrieve_information( - queries=section_query, search_top_k=self.retrieve_top_k - ) - output = self.section_gen( - topic=topic, - outline=section_outline, - section=section_name, - collected_info=collected_info, - ) - return { - "section_name": section_name, - "section_content": output.section, - "collected_info": collected_info, - } - - def generate_article( - self, - topic: str, - information_table: StormInformationTable, - article_with_outline: StormArticle, - callback_handler: BaseCallbackHandler = None, - ) -> StormArticle: + collected_info = information_table.retrieve_information(queries=section_query, + search_top_k=self.retrieve_top_k) + output = self.section_gen(topic=topic, + outline=section_outline, + section=section_name, + collected_info=collected_info) + return {"section_name": section_name, "section_content": output.section, "collected_info": collected_info} + + def generate_article(self, + topic: str, + information_table: StormInformationTable, + article_with_outline: StormArticle, + callback_handler: BaseCallbackHandler = None) -> StormArticle: """ Generate article for the topic based on the information table and article outline. @@ -76,48 +63,35 @@ def generate_article( section_output_dict_collection = [] if len(sections_to_write) == 0: - logging.error( - f"No outline for {topic}. Will directly search with the topic." - ) + logging.error(f'No outline for {topic}. Will directly search with the topic.') section_output_dict = self.generate_section( topic=topic, section_name=topic, information_table=information_table, section_outline="", - section_query=[topic], + section_query=[topic] ) section_output_dict_collection = [section_output_dict] else: - with concurrent.futures.ThreadPoolExecutor( - max_workers=self.max_thread_num - ) as executor: + with concurrent.futures.ThreadPoolExecutor(max_workers=self.max_thread_num) as executor: future_to_sec_title = {} for section_title in sections_to_write: # We don't want to write a separate introduction section. - if section_title.lower().strip() == "introduction": + if section_title.lower().strip() == 'introduction': continue # We don't want to write a separate conclusion section. 
if section_title.lower().strip().startswith( - "conclusion" - ) or section_title.lower().strip().startswith("summary"): + 'conclusion') or section_title.lower().strip().startswith('summary'): continue - section_query = article_with_outline.get_outline_as_list( - root_section_name=section_title, add_hashtags=False - ) + section_query = article_with_outline.get_outline_as_list(root_section_name=section_title, + add_hashtags=False) queries_with_hashtags = article_with_outline.get_outline_as_list( - root_section_name=section_title, add_hashtags=True - ) + root_section_name=section_title, add_hashtags=True) section_outline = "\n".join(queries_with_hashtags) future_to_sec_title[ - executor.submit( - self.generate_section, - topic, - section_title, - information_table, - section_outline, - section_query, - ) + executor.submit(self.generate_section, + topic, section_title, information_table, section_outline, section_query) ] = section_title for future in as_completed(future_to_sec_title): @@ -125,11 +99,9 @@ def generate_article( article = copy.deepcopy(article_with_outline) for section_output_dict in section_output_dict_collection: - article.update_section( - parent_section_name=topic, - current_section_content=section_output_dict["section_content"], - current_section_info_list=section_output_dict["collected_info"], - ) + article.update_section(parent_section_name=topic, + current_section_content=section_output_dict["section_content"], + current_section_info_list=section_output_dict["collected_info"]) article.post_processing() return article @@ -142,24 +114,17 @@ def __init__(self, engine: Union[dspy.dsp.LM, dspy.dsp.HFModel]): self.write_section = dspy.Predict(WriteSection) self.engine = engine - def forward( - self, - topic: str, - outline: str, - section: str, - collected_info: List[StormInformation], - ): - info = "" + def forward(self, topic: str, outline: str, section: str, collected_info: List[StormInformation]): + info = '' for idx, storm_info in enumerate(collected_info): - info += f"[{idx + 1}]\n" + "\n".join(storm_info.snippets) - info += "\n\n" + info += f'[{idx + 1}]\n' + '\n'.join(storm_info.snippets) + info += '\n\n' info = ArticleTextProcessing.limit_word_count_preserve_newline(info, 1500) with dspy.settings.context(lm=self.engine): section = ArticleTextProcessing.clean_up_section( - self.write_section(topic=topic, info=info, section=section).output - ) + self.write_section(topic=topic, info=info, section=section).output) return dspy.Prediction(section=section) @@ -167,9 +132,9 @@ def forward( class WriteSection(dspy.Signature): """Write a Wikipedia section based on the collected information. - Here is the format of your writing: - 1. Use "#" Title" to indicate section title, "##" Title" to indicate subsection title, "###" Title" to indicate subsubsection title, and so on. - 2. Use [1], [2], ..., [n] in line (for example, "The capital of the United States is Washington, D.C.[1][3]."). You DO NOT need to include a References or Sources section to list the sources at the end. + Here is the format of your writing: + 1. Use "#" Title" to indicate section title, "##" Title" to indicate subsection title, "###" Title" to indicate subsubsection title, and so on. + 2. Use [1], [2], ..., [n] in line (for example, "The capital of the United States is Washington, D.C.[1][3]."). You DO NOT need to include a References or Sources section to list the sources at the end. 
""" info = dspy.InputField(prefix="The collected information:\n", format=str) @@ -177,5 +142,5 @@ class WriteSection(dspy.Signature): section = dspy.InputField(prefix="The section you need to write: ", format=str) output = dspy.OutputField( prefix="Write the section with proper inline citations (Start your writing with # section title. Don't include the page title or try to write other sections):\n", - format=str, + format=str ) diff --git a/knowledge_storm/storm_wiki/modules/article_polish.py b/knowledge_storm/storm_wiki/modules/article_polish.py index fb85b0f3..b70bb834 100644 --- a/knowledge_storm/storm_wiki/modules/article_polish.py +++ b/knowledge_storm/storm_wiki/modules/article_polish.py @@ -14,21 +14,21 @@ class StormArticlePolishingModule(ArticlePolishingModule): knowledge curation stage, generated outline from outline generation stage. """ - def __init__( - self, - article_gen_lm: Union[dspy.dsp.LM, dspy.dsp.HFModel], - article_polish_lm: Union[dspy.dsp.LM, dspy.dsp.HFModel], - ): + def __init__(self, + article_gen_lm: Union[dspy.dsp.LM, dspy.dsp.HFModel], + article_polish_lm: Union[dspy.dsp.LM, dspy.dsp.HFModel]): self.article_gen_lm = article_gen_lm self.article_polish_lm = article_polish_lm self.polish_page = PolishPageModule( - write_lead_engine=self.article_gen_lm, polish_engine=self.article_polish_lm + write_lead_engine=self.article_gen_lm, + polish_engine=self.article_polish_lm ) - def polish_article( - self, topic: str, draft_article: StormArticle, remove_duplicate: bool = False - ) -> StormArticle: + def polish_article(self, + topic: str, + draft_article: StormArticle, + remove_duplicate: bool = False) -> StormArticle: """ Polish article. @@ -39,14 +39,10 @@ def polish_article( """ article_text = draft_article.to_string() - polish_result = self.polish_page( - topic=topic, draft_page=article_text, polish_whole_page=remove_duplicate - ) + polish_result = self.polish_page(topic=topic, draft_page=article_text, polish_whole_page=remove_duplicate) lead_section = f"# summary\n{polish_result.lead_section}" - polished_article = "\n\n".join([lead_section, polish_result.page]) - polished_article_dict = ArticleTextProcessing.parse_article_into_dict( - polished_article - ) + polished_article = '\n\n'.join([lead_section, polish_result.page]) + polished_article_dict = ArticleTextProcessing.parse_article_into_dict(polished_article) polished_article = copy.deepcopy(draft_article) polished_article.insert_or_create_section(article_dict=polished_article_dict) polished_article.post_processing() @@ -55,10 +51,9 @@ def polish_article( class WriteLeadSection(dspy.Signature): """Write a lead section for the given Wikipedia page with the following guidelines: - 1. The lead should stand on its own as a concise overview of the article's topic. It should identify the topic, establish context, explain why the topic is notable, and summarize the most important points, including any prominent controversies. - 2. The lead section should be concise and contain no more than four well-composed paragraphs. - 3. The lead section should be carefully sourced as appropriate. Add inline citations (e.g., "Washington, D.C., is the capital of the United States.[1][3].") where necessary. - """ + 1. The lead should stand on its own as a concise overview of the article's topic. It should identify the topic, establish context, explain why the topic is notable, and summarize the most important points, including any prominent controversies. + 2. 
The lead section should be concise and contain no more than four well-composed paragraphs. + 3. The lead section should be carefully sourced as appropriate. Add inline citations (e.g., "Washington, D.C., is the capital of the United States.[1][3].") where necessary.""" topic = dspy.InputField(prefix="The topic of the page: ", format=str) draft_page = dspy.InputField(prefix="The draft page:\n", format=str) @@ -73,11 +68,8 @@ class PolishPage(dspy.Signature): class PolishPageModule(dspy.Module): - def __init__( - self, - write_lead_engine: Union[dspy.dsp.LM, dspy.dsp.HFModel], - polish_engine: Union[dspy.dsp.LM, dspy.dsp.HFModel], - ): + def __init__(self, write_lead_engine: Union[dspy.dsp.LM, dspy.dsp.HFModel], + polish_engine: Union[dspy.dsp.LM, dspy.dsp.HFModel]): super().__init__() self.write_lead_engine = write_lead_engine self.polish_engine = polish_engine @@ -86,9 +78,7 @@ def __init__( def forward(self, topic: str, draft_page: str, polish_whole_page: bool = True): with dspy.settings.context(lm=self.write_lead_engine): - lead_section = self.write_lead( - topic=topic, draft_page=draft_page - ).lead_section + lead_section = self.write_lead(topic=topic, draft_page=draft_page).lead_section if "The lead section:" in lead_section: lead_section = lead_section.split("The lead section:")[1].strip() if polish_whole_page: diff --git a/knowledge_storm/storm_wiki/modules/knowledge_curation.py b/knowledge_storm/storm_wiki/modules/knowledge_curation.py index bde27678..8e881c65 100644 --- a/knowledge_storm/storm_wiki/modules/knowledge_curation.py +++ b/knowledge_storm/storm_wiki/modules/knowledge_curation.py @@ -25,32 +25,20 @@ class ConvSimulator(dspy.Module): """Simulate a conversation between a Wikipedia writer with specific persona and an expert.""" - def __init__( - self, - topic_expert_engine: Union[dspy.dsp.LM, dspy.dsp.HFModel], - question_asker_engine: Union[dspy.dsp.LM, dspy.dsp.HFModel], - retriever: Retriever, - max_search_queries_per_turn: int, - search_top_k: int, - max_turn: int, - ): + def __init__(self, topic_expert_engine: Union[dspy.dsp.LM, dspy.dsp.HFModel], + question_asker_engine: Union[dspy.dsp.LM, dspy.dsp.HFModel], + retriever: Retriever, max_search_queries_per_turn: int, search_top_k: int, max_turn: int): super().__init__() self.wiki_writer = WikiWriter(engine=question_asker_engine) self.topic_expert = TopicExpert( engine=topic_expert_engine, max_search_queries=max_search_queries_per_turn, search_top_k=search_top_k, - retriever=retriever, + retriever=retriever ) self.max_turn = max_turn - def forward( - self, - topic: str, - persona: str, - ground_truth_url: str, - callback_handler: BaseCallbackHandler, - ): + def forward(self, topic: str, persona: str, ground_truth_url: str, callback_handler: BaseCallbackHandler): """ topic: The topic to research. persona: The persona of the Wikipedia writer. 
@@ -58,22 +46,18 @@ def forward(
         """
         dlg_history: List[DialogueTurn] = []
         for _ in range(self.max_turn):
-            user_utterance = self.wiki_writer(
-                topic=topic, persona=persona, dialogue_turns=dlg_history
-            ).question
-            if user_utterance == "":
-                logging.error("Simulated Wikipedia writer utterance is empty.")
+            user_utterance = self.wiki_writer(topic=topic, persona=persona, dialogue_turns=dlg_history).question
+            if user_utterance == '':
+                logging.error('Simulated Wikipedia writer utterance is empty.')
                 break
-            if user_utterance.startswith("Thank you so much for your help!"):
+            if user_utterance.startswith('Thank you so much for your help!'):
                 break
-            expert_output = self.topic_expert(
-                topic=topic, question=user_utterance, ground_truth_url=ground_truth_url
-            )
+            expert_output = self.topic_expert(topic=topic, question=user_utterance, ground_truth_url=ground_truth_url)
             dlg_turn = DialogueTurn(
                 agent_utterance=expert_output.answer,
                 user_utterance=user_utterance,
                 search_queries=expert_output.queries,
-                search_results=expert_output.searched_results,
+                search_results=expert_output.searched_results
             )
             dlg_history.append(dlg_turn)
             callback_handler.on_dialogue_turn_end(dlg_turn=dlg_turn)
@@ -92,35 +76,22 @@ def __init__(self, engine: Union[dspy.dsp.LM, dspy.dsp.HFModel]):
         self.ask_question = dspy.ChainOfThought(AskQuestion)
         self.engine = engine

-    def forward(
-        self,
-        topic: str,
-        persona: str,
-        dialogue_turns: List[DialogueTurn],
-        draft_page=None,
-    ):
+    def forward(self, topic: str, persona: str, dialogue_turns: List[DialogueTurn], draft_page=None):
         conv = []
         for turn in dialogue_turns[:-4]:
-            conv.append(
-                f"You: {turn.user_utterance}\nExpert: Omit the answer here due to space limit."
-            )
+            conv.append(f'You: {turn.user_utterance}\nExpert: Omit the answer here due to space limit.')
         for turn in dialogue_turns[-4:]:
             conv.append(
-                f"You: {turn.user_utterance}\nExpert: {ArticleTextProcessing.remove_citations(turn.agent_utterance)}"
-            )
-        conv = "\n".join(conv)
-        conv = conv.strip() or "N/A"
+                f'You: {turn.user_utterance}\nExpert: {ArticleTextProcessing.remove_citations(turn.agent_utterance)}')
+        conv = '\n'.join(conv)
+        conv = conv.strip() or 'N/A'
         conv = ArticleTextProcessing.limit_word_count_preserve_newline(conv, 2500)

         with dspy.settings.context(lm=self.engine):
             if persona is not None and len(persona.strip()) > 0:
-                question = self.ask_question_with_persona(
-                    topic=topic, persona=persona, conv=conv
-                ).question
+                question = self.ask_question_with_persona(topic=topic, persona=persona, conv=conv).question
             else:
-                question = self.ask_question(
-                    topic=topic, persona=persona, conv=conv
-                ).question
+                question = self.ask_question(topic=topic, persona=persona, conv=conv).question

         return dspy.Prediction(question=question)


@@ -128,11 +99,10 @@ def forward(
 class AskQuestion(dspy.Signature):
     """You are an experienced Wikipedia writer. You are chatting with an expert to get information for the topic you want to contribute. Ask good questions to get more useful information relevant to the topic. When you have no more questions to ask, say "Thank you so much for your help!" to end the conversation.
-    Please only ask one question at a time and don't ask what you have asked before. Your questions should be related to the topic you want to write.
-    """
+    Please only ask one question at a time and don't ask what you have asked before. Your questions should be related to the topic you want to write."""

-    topic = dspy.InputField(prefix="Topic you want to write: ", format=str)
-    conv = dspy.InputField(prefix="Conversation history:\n", format=str)
+    topic = dspy.InputField(prefix='Topic you want to write: ', format=str)
+    conv = dspy.InputField(prefix='Conversation history:\n', format=str)
     question = dspy.OutputField(format=str)


@@ -140,41 +110,38 @@ class AskQuestionWithPersona(dspy.Signature):
     """You are an experienced Wikipedia writer and want to edit a specific page. Besides your identity as a Wikipedia writer, you have specific focus when researching the topic.
     Now, you are chatting with an expert to get information. Ask good questions to get more useful information.
     When you have no more questions to ask, say "Thank you so much for your help!" to end the conversation.
-    Please only ask one question at a time and don't ask what you have asked before. Your questions should be related to the topic you want to write.
-    """
+    Please only ask one question at a time and don't ask what you have asked before. Your questions should be related to the topic you want to write."""

-    topic = dspy.InputField(prefix="Topic you want to write: ", format=str)
-    persona = dspy.InputField(
-        prefix="Your persona besides being a Wikipedia writer: ", format=str
-    )
-    conv = dspy.InputField(prefix="Conversation history:\n", format=str)
+    topic = dspy.InputField(prefix='Topic you want to write: ', format=str)
+    persona = dspy.InputField(prefix='Your persona besides being a Wikipedia writer: ', format=str)
+    conv = dspy.InputField(prefix='Conversation history:\n', format=str)
     question = dspy.OutputField(format=str)


 class QuestionToQuery(dspy.Signature):
     """You want to answer the question using Google search. What do you type in the search box?
-    Write the queries you will use in the following format:
-    - query 1
-    - query 2
-    ...
-    - query n"""
+    Write the queries you will use in the following format:
+    - query 1
+    - query 2
+    ...
+    - query n"""

-    topic = dspy.InputField(prefix="Topic you are discussing: ", format=str)
-    question = dspy.InputField(prefix="Question you want to answer: ", format=str)
+    topic = dspy.InputField(prefix='Topic you are discussing: ', format=str)
+    question = dspy.InputField(prefix='Question you want to answer: ', format=str)
     queries = dspy.OutputField(format=str)


 class AnswerQuestion(dspy.Signature):
     """You are an expert who can use information effectively. You are chatting with a Wikipedia writer who wants to write a Wikipedia page on the topic you know. You have gathered the related information and will now use the information to form a response.
-    Make your response as informative as possible and make sure every sentence is supported by the gathered information. If [Gathered information] is not related to the [Topic] and [Question], output "Sorry, I don't have enough information to answer the question.".
-    """
+    Make your response as informative as possible and make sure every sentence is supported by the gathered information. If [Gathered information] is not related to the [Topic] and [Question], output "Sorry, I don't have enough information to answer the question.".""" 

-    topic = dspy.InputField(prefix="Topic you are discussing:", format=str)
-    conv = dspy.InputField(prefix="Question:\n", format=str)
-    info = dspy.InputField(prefix="Gathered information:\n", format=str)
+    topic = dspy.InputField(prefix='Topic you are discussing:', format=str)
+    conv = dspy.InputField(prefix='Question:\n', format=str)
+    info = dspy.InputField(
+        prefix='Gathered information:\n', format=str)
     answer = dspy.OutputField(
-        prefix="Now give your response. (Try to use as many different sources as possible and do not hallucinate.)\n",
-        format=str,
+        prefix='Now give your response. (Try to use as many different sources as possible and do not hallucinate.)\n',
+        format=str
     )


@@ -186,13 +153,8 @@ class TopicExpert(dspy.Module):
     4. Generate an answer using the retrieved information.
     """

-    def __init__(
-        self,
-        engine: Union[dspy.dsp.LM, dspy.dsp.HFModel],
-        max_search_queries: int,
-        search_top_k: int,
-        retriever: Retriever,
-    ):
+    def __init__(self, engine: Union[dspy.dsp.LM, dspy.dsp.HFModel],
+                 max_search_queries: int, search_top_k: int, retriever: Retriever):
         super().__init__()
         self.generate_queries = dspy.Predict(QuestionToQuery)
         self.retriever = retriever
@@ -206,43 +168,31 @@ def forward(self, topic: str, question: str, ground_truth_url: str):
         with dspy.settings.context(lm=self.engine):
             # Identify: Break down question into queries.
             queries = self.generate_queries(topic=topic, question=question).queries
-            queries = [
-                q.replace("-", "").strip().strip('"').strip('"').strip()
-                for q in queries.split("\n")
-            ]
-            queries = queries[: self.max_search_queries]
+            queries = [q.replace('-', '').strip().strip('"').strip('"').strip() for q in queries.split('\n')]
+            queries = queries[:self.max_search_queries]
             # Search
-            searched_results: List[StormInformation] = self.retriever.retrieve(
-                list(set(queries)), exclude_urls=[ground_truth_url]
-            )
+            searched_results: List[StormInformation] = self.retriever.retrieve(list(set(queries)),
+                                                                               exclude_urls=[ground_truth_url])
             if len(searched_results) > 0:
                 # Evaluate: Simplify this part by directly using the top 1 snippet.
-                info = ""
+                info = ''
                 for n, r in enumerate(searched_results):
-                    info += "\n".join(f"[{n + 1}]: {s}" for s in r.snippets[:1])
-                    info += "\n\n"
+                    info += '\n'.join(f'[{n + 1}]: {s}' for s in r.snippets[:1])
+                    info += '\n\n'

-                info = ArticleTextProcessing.limit_word_count_preserve_newline(
-                    info, 1000
-                )
+                info = ArticleTextProcessing.limit_word_count_preserve_newline(info, 1000)
                 try:
-                    answer = self.answer_question(
-                        topic=topic, conv=question, info=info
-                    ).answer
-                    answer = ArticleTextProcessing.remove_uncompleted_sentences_with_citations(
-                        answer
-                    )
+                    answer = self.answer_question(topic=topic, conv=question, info=info).answer
+                    answer = ArticleTextProcessing.remove_uncompleted_sentences_with_citations(answer)
                 except Exception as e:
-                    logging.error(f"Error occurs when generating answer: {e}")
-                    answer = "Sorry, I cannot answer this question. Please ask another question."
+                    logging.error(f'Error occurs when generating answer: {e}')
+                    answer = 'Sorry, I cannot answer this question. Please ask another question.'
             else:
                 # When no information is found, the expert shouldn't hallucinate.
-                answer = "Sorry, I cannot find information for this question. Please ask another question."
+                answer = 'Sorry, I cannot find information for this question. 
Please ask another question.' - return dspy.Prediction( - queries=queries, searched_results=searched_results, answer=answer - ) + return dspy.Prediction(queries=queries, searched_results=searched_results, answer=answer) class StormKnowledgeCurationModule(KnowledgeCurationModule): @@ -250,17 +200,15 @@ class StormKnowledgeCurationModule(KnowledgeCurationModule): The interface for knowledge curation stage. Given topic, return collected information. """ - def __init__( - self, - retriever: Retriever, - persona_generator: Optional[StormPersonaGenerator], - conv_simulator_lm: Union[dspy.dsp.LM, dspy.dsp.HFModel], - question_asker_lm: Union[dspy.dsp.LM, dspy.dsp.HFModel], - max_search_queries_per_turn: int, - search_top_k: int, - max_conv_turn: int, - max_thread_num: int, - ): + def __init__(self, + retriever: Retriever, + persona_generator: Optional[StormPersonaGenerator], + conv_simulator_lm: Union[dspy.dsp.LM, dspy.dsp.HFModel], + question_asker_lm: Union[dspy.dsp.LM, dspy.dsp.HFModel], + max_search_queries_per_turn: int, + search_top_k: int, + max_conv_turn: int, + max_thread_num: int): """ Store args and finish initialization. """ @@ -276,22 +224,14 @@ def __init__( retriever=retriever, max_search_queries_per_turn=max_search_queries_per_turn, search_top_k=search_top_k, - max_turn=max_conv_turn, + max_turn=max_conv_turn ) def _get_considered_personas(self, topic: str, max_num_persona) -> List[str]: - return self.persona_generator.generate_persona( - topic=topic, max_num_persona=max_num_persona - ) + return self.persona_generator.generate_persona(topic=topic, max_num_persona=max_num_persona) - def _run_conversation( - self, - conv_simulator, - topic, - ground_truth_url, - considered_personas, - callback_handler: BaseCallbackHandler, - ) -> List[Tuple[str, List[DialogueTurn]]]: + def _run_conversation(self, conv_simulator, topic, ground_truth_url, considered_personas, + callback_handler: BaseCallbackHandler) -> List[Tuple[str, List[DialogueTurn]]]: """ Executes multiple conversation simulations concurrently, each with a different persona, and collects their dialog histories. The dialog history of each conversation is cleaned @@ -320,16 +260,13 @@ def run_conv(persona): topic=topic, ground_truth_url=ground_truth_url, persona=persona, - callback_handler=callback_handler, + callback_handler=callback_handler ) max_workers = min(self.max_thread_num, len(considered_personas)) with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: - future_to_persona = { - executor.submit(run_conv, persona): persona - for persona in considered_personas - } + future_to_persona = {executor.submit(run_conv, persona): persona for persona in considered_personas} if streamlit_connection: # Ensure the logging context is correct when connecting with Streamlit frontend. 
@@ -339,27 +276,23 @@ def run_conv(persona): for future in as_completed(future_to_persona): persona = future_to_persona[future] conv = future.result() - conversations.append( - (persona, ArticleTextProcessing.clean_up_citation(conv).dlg_history) - ) + conversations.append((persona, ArticleTextProcessing.clean_up_citation(conv).dlg_history)) return conversations - def research( - self, - topic: str, - ground_truth_url: str, - callback_handler: BaseCallbackHandler, - max_perspective: int = 0, - disable_perspective: bool = True, - return_conversation_log=False, - ) -> Union[StormInformationTable, Tuple[StormInformationTable, Dict]]: + def research(self, + topic: str, + ground_truth_url: str, + callback_handler: BaseCallbackHandler, + max_perspective: int = 0, + disable_perspective: bool = True, + return_conversation_log=False) -> Union[StormInformationTable, Tuple[StormInformationTable, Dict]]: """ Curate information and knowledge for the given topic Args: topic: topic of interest in natural language. - + Returns: collected_information: collected information in InformationTable type. """ @@ -370,25 +303,19 @@ def research( if disable_perspective: considered_personas = [""] else: - considered_personas = self._get_considered_personas( - topic=topic, max_num_persona=max_perspective - ) + considered_personas = self._get_considered_personas(topic=topic, max_num_persona=max_perspective) callback_handler.on_identify_perspective_end(perspectives=considered_personas) - # run conversation + # run conversation callback_handler.on_information_gathering_start() - conversations = self._run_conversation( - conv_simulator=self.conv_simulator, - topic=topic, - ground_truth_url=ground_truth_url, - considered_personas=considered_personas, - callback_handler=callback_handler, - ) + conversations = self._run_conversation(conv_simulator=self.conv_simulator, + topic=topic, + ground_truth_url=ground_truth_url, + considered_personas=considered_personas, + callback_handler=callback_handler) information_table = StormInformationTable(conversations) callback_handler.on_information_gathering_end() if return_conversation_log: - return information_table, StormInformationTable.construct_log_dict( - conversations - ) + return information_table, StormInformationTable.construct_log_dict(conversations) return information_table diff --git a/knowledge_storm/storm_wiki/modules/outline_generation.py b/knowledge_storm/storm_wiki/modules/outline_generation.py index a96c7978..1f45b1c2 100644 --- a/knowledge_storm/storm_wiki/modules/outline_generation.py +++ b/knowledge_storm/storm_wiki/modules/outline_generation.py @@ -14,19 +14,18 @@ class StormOutlineGenerationModule(OutlineGenerationModule): curation stage, generate outline for the article. 
""" - def __init__(self, outline_gen_lm: Union[dspy.dsp.LM, dspy.dsp.HFModel]): + def __init__(self, + outline_gen_lm: Union[dspy.dsp.LM, dspy.dsp.HFModel]): super().__init__() self.outline_gen_lm = outline_gen_lm self.write_outline = WriteOutline(engine=self.outline_gen_lm) - def generate_outline( - self, - topic: str, - information_table: StormInformationTable, - old_outline: Optional[StormArticle] = None, - callback_handler: BaseCallbackHandler = None, - return_draft_outline=False, - ) -> Union[StormArticle, Tuple[StormArticle, StormArticle]]: + def generate_outline(self, + topic: str, + information_table: StormInformationTable, + old_outline: Optional[StormArticle] = None, + callback_handler: BaseCallbackHandler = None, + return_draft_outline=False) -> Union[StormArticle, Tuple[StormArticle, StormArticle]]: """ Generates an outline for an article based on the specified topic and the information gathered during the knowledge curation stage. This method can optionally return both the @@ -35,38 +34,30 @@ def generate_outline( Args: topic (str): The topic of the article. information_table (StormInformationTable): The information table containing the collected information. - old_outline (Optional[StormArticle]): An optional previous version of the article outline that can + old_outline (Optional[StormArticle]): An optional previous version of the article outline that can be used for reference or comparison. Defaults to None. - callback_handler (BaseCallbackHandler): An optional callback handler that can be used to trigger - custom callbacks at various stages of the outline generation process, such as when the information + callback_handler (BaseCallbackHandler): An optional callback handler that can be used to trigger + custom callbacks at various stages of the outline generation process, such as when the information organization starts. Defaults to None. - return_draft_outline (bool): A flag indicating whether the method should return both the final article - outline and a draft version of the outline. If False, only the final article outline is returned. + return_draft_outline (bool): A flag indicating whether the method should return both the final article + outline and a draft version of the outline. If False, only the final article outline is returned. Defaults to False. Returns: - Union[StormArticle, Tuple[StormArticle, StormArticle]]: Depending on the value of `return_draft_outline`, - this method returns either a single `StormArticle` object containing the final outline or a tuple of - two `StormArticle` objects, the first containing the final outline and the second containing the + Union[StormArticle, Tuple[StormArticle, StormArticle]]: Depending on the value of `return_draft_outline`, + this method returns either a single `StormArticle` object containing the final outline or a tuple of + two `StormArticle` objects, the first containing the final outline and the second containing the draft outline. 
""" if callback_handler is not None: callback_handler.on_information_organization_start() - concatenated_dialogue_turns = sum( - [conv for (_, conv) in information_table.conversations], [] - ) - result = self.write_outline( - topic=topic, - dlg_history=concatenated_dialogue_turns, - callback_handler=callback_handler, - ) - article_with_outline_only = StormArticle.from_outline_str( - topic=topic, outline_str=result.outline - ) - article_with_draft_outline_only = StormArticle.from_outline_str( - topic=topic, outline_str=result.old_outline - ) + concatenated_dialogue_turns = sum([conv for (_, conv) in information_table.conversations], []) + result = self.write_outline(topic=topic, dlg_history=concatenated_dialogue_turns, + callback_handler=callback_handler) + article_with_outline_only = StormArticle.from_outline_str(topic=topic, outline_str=result.outline) + article_with_draft_outline_only = StormArticle.from_outline_str(topic=topic, + outline_str=result.old_outline) if not return_draft_outline: return article_with_outline_only return article_with_outline_only, article_with_draft_outline_only @@ -81,44 +72,25 @@ def __init__(self, engine: Union[dspy.dsp.LM, dspy.dsp.HFModel]): self.write_page_outline = dspy.Predict(WritePageOutlineFromConv) self.engine = engine - def forward( - self, - topic: str, - dlg_history, - old_outline: Optional[str] = None, - callback_handler: BaseCallbackHandler = None, - ): + def forward(self, topic: str, dlg_history, old_outline: Optional[str] = None, + callback_handler: BaseCallbackHandler = None): trimmed_dlg_history = [] for turn in dlg_history: - if ( - "topic you" in turn.agent_utterance.lower() - or "topic you" in turn.user_utterance.lower() - ): + if 'topic you' in turn.agent_utterance.lower() or 'topic you' in turn.user_utterance.lower(): continue trimmed_dlg_history.append(turn) - conv = "\n".join( - [ - f"Wikipedia Writer: {turn.user_utterance}\nExpert: {turn.agent_utterance}" - for turn in trimmed_dlg_history - ] - ) + conv = '\n'.join([f'Wikipedia Writer: {turn.user_utterance}\nExpert: {turn.agent_utterance}' for turn in + trimmed_dlg_history]) conv = ArticleTextProcessing.remove_citations(conv) conv = ArticleTextProcessing.limit_word_count_preserve_newline(conv, 5000) with dspy.settings.context(lm=self.engine): if old_outline is None: - old_outline = ArticleTextProcessing.clean_up_outline( - self.draft_page_outline(topic=topic).outline - ) + old_outline = ArticleTextProcessing.clean_up_outline(self.draft_page_outline(topic=topic).outline) if callback_handler: - callback_handler.on_direct_outline_generation_end( - outline=old_outline - ) + callback_handler.on_direct_outline_generation_end(outline=old_outline) outline = ArticleTextProcessing.clean_up_outline( - self.write_page_outline( - topic=topic, old_outline=old_outline, conv=conv - ).outline - ) + self.write_page_outline(topic=topic, old_outline=old_outline, conv=conv).outline) if callback_handler: callback_handler.on_outline_refinement_end(outline=outline) @@ -127,10 +99,10 @@ def forward( class WritePageOutline(dspy.Signature): """Write an outline for a Wikipedia page. - Here is the format of your writing: - 1. Use "#" Title" to indicate section title, "##" Title" to indicate subsection title, "###" Title" to indicate subsubsection title, and so on. - 2. Do not include other information. - 3. Do not include topic name itself in the outline. + Here is the format of your writing: + 1. 
Use "#" Title" to indicate section title, "##" Title" to indicate subsection title, "###" Title" to indicate subsubsection title, and so on. + 2. Do not include other information. + 3. Do not include topic name itself in the outline. """ topic = dspy.InputField(prefix="The topic you want to write: ", format=str) @@ -152,10 +124,10 @@ def forward(self, topic: str): class WritePageOutlineFromConv(dspy.Signature): """Improve an outline for a Wikipedia page. You already have a draft outline that covers the general information. Now you want to improve it based on the information learned from an information-seeking conversation to make it more informative. - Here is the format of your writing: - 1. Use "#" Title" to indicate section title, "##" Title" to indicate subsection title, "###" Title" to indicate subsubsection title, and so on. - 2. Do not include other information. - 3. Do not include topic name itself in the outline. + Here is the format of your writing: + 1. Use "#" Title" to indicate section title, "##" Title" to indicate subsection title, "###" Title" to indicate subsubsection title, and so on. + 2. Do not include other information. + 3. Do not include topic name itself in the outline. """ topic = dspy.InputField(prefix="The topic you want to write: ", format=str) @@ -163,5 +135,5 @@ class WritePageOutlineFromConv(dspy.Signature): old_outline = dspy.OutputField(prefix="Current outline:\n", format=str) outline = dspy.OutputField( prefix='Write the Wikipedia page outline (Use "#" Title" to indicate section title, "##" Title" to indicate subsection title, ...):\n', - format=str, + format=str ) diff --git a/knowledge_storm/storm_wiki/modules/persona_generator.py b/knowledge_storm/storm_wiki/modules/persona_generator.py index c51dc0cc..5150e31b 100644 --- a/knowledge_storm/storm_wiki/modules/persona_generator.py +++ b/knowledge_storm/storm_wiki/modules/persona_generator.py @@ -11,27 +11,19 @@ def get_wiki_page_title_and_toc(url): """Get the main title and table of contents from an url of a Wikipedia page.""" response = requests.get(url) - soup = BeautifulSoup(response.content, "html.parser") + soup = BeautifulSoup(response.content, 'html.parser') # Get the main title from the first h1 tag - main_title = soup.find("h1").text.replace("[edit]", "").strip().replace("\xa0", " ") + main_title = soup.find('h1').text.replace('[edit]', '').strip().replace('\xa0', ' ') toc = "" levels = [] - excluded_sections = { - "Contents", - "See also", - "Notes", - "References", - "External links", - } + excluded_sections = {'Contents', 'See also', 'Notes', 'References', 'External links'} # Start processing from h2 to exclude the main title from TOC - for header in soup.find_all(["h2", "h3", "h4", "h5", "h6"]): - level = int( - header.name[1] - ) # Extract the numeric part of the header tag (e.g., '2' from 'h2') - section_title = header.text.replace("[edit]", "").strip().replace("\xa0", " ") + for header in soup.find_all(['h2', 'h3', "h4", "h5", "h6"]): + level = int(header.name[1]) # Extract the numeric part of the header tag (e.g., '2' from 'h2') + section_title = header.text.replace('[edit]', '').strip().replace('\xa0', ' ') if section_title in excluded_sections: continue @@ -47,9 +39,9 @@ def get_wiki_page_title_and_toc(url): class FindRelatedTopic(dspy.Signature): """I'm writing a Wikipedia page for a topic mentioned below. Please identify and recommend some Wikipedia pages on closely related subjects. 
I'm looking for examples that provide insights into interesting aspects commonly associated with this topic, or examples that help me understand the typical content and structure included in Wikipedia pages for similar topics. - Please list the urls in separate lines.""" + Please list the urls in separate lines.""" - topic = dspy.InputField(prefix="Topic of interest:", format=str) + topic = dspy.InputField(prefix='Topic of interest:', format=str) related_topics = dspy.OutputField(format=str) @@ -58,10 +50,8 @@ class GenPersona(dspy.Signature): Give your answer in the following format: 1. short summary of editor 1: description\n2. short summary of editor 2: description\n... """ - topic = dspy.InputField(prefix="Topic of interest:", format=str) - examples = dspy.InputField( - prefix="Wiki page outlines of related topics for inspiration:\n", format=str - ) + topic = dspy.InputField(prefix='Topic of interest:', format=str) + examples = dspy.InputField(prefix='Wiki page outlines of related topics for inspiration:\n', format=str) personas = dspy.OutputField(format=str) @@ -79,44 +69,38 @@ def forward(self, topic: str, draft=None): # Get section names from wiki pages of relevant topics for inspiration. related_topics = self.find_related_topic(topic=topic).related_topics urls = [] - for s in related_topics.split("\n"): - if "http" in s: - urls.append(s[s.find("http") :]) + for s in related_topics.split('\n'): + if 'http' in s: + urls.append(s[s.find('http'):]) examples = [] for url in urls: try: title, toc = get_wiki_page_title_and_toc(url) - examples.append(f"Title: {title}\nTable of Contents: {toc}") + examples.append(f'Title: {title}\nTable of Contents: {toc}') except Exception as e: - logging.error(f"Error occurs when processing {url}: {e}") + logging.error(f'Error occurs when processing {url}: {e}') continue if len(examples) == 0: - examples.append("N/A") - gen_persona_output = self.gen_persona( - topic=topic, examples="\n----------\n".join(examples) - ).personas + examples.append('N/A') + gen_persona_output = self.gen_persona(topic=topic, examples='\n----------\n'.join(examples)).personas personas = [] - for s in gen_persona_output.split("\n"): - match = re.search(r"\d+\.\s*(.*)", s) + for s in gen_persona_output.split('\n'): + match = re.search(r'\d+\.\s*(.*)', s) if match: personas.append(match.group(1)) sorted_personas = personas - return dspy.Prediction( - personas=personas, - raw_personas_output=sorted_personas, - related_topics=related_topics, - ) + return dspy.Prediction(personas=personas, raw_personas_output=sorted_personas, related_topics=related_topics) -class StormPersonaGenerator: +class StormPersonaGenerator(): """ A generator class for creating personas based on a given topic. - This class uses an underlying engine to generate personas tailored to the specified topic. - The generator integrates with a `CreateWriterWithPersona` instance to create diverse personas, + This class uses an underlying engine to generate personas tailored to the specified topic. + The generator integrates with a `CreateWriterWithPersona` instance to create diverse personas, including a default 'Basic fact writer' persona. Attributes: @@ -149,6 +133,6 @@ def generate_persona(self, topic: str, max_num_persona: int = 3) -> List[str]: and up to `max_num_persona` additional personas generated based on the topic. """ personas = self.create_writer_with_persona(topic=topic) - default_persona = "Basic fact writer: Basic fact writer focusing on broadly covering the basic facts about the topic." 
+ default_persona = 'Basic fact writer: Basic fact writer focusing on broadly covering the basic facts about the topic.' considered_personas = [default_persona] + personas.personas[:max_num_persona] return considered_personas diff --git a/knowledge_storm/storm_wiki/modules/retriever.py b/knowledge_storm/storm_wiki/modules/retriever.py index 85df63ec..179ae99b 100644 --- a/knowledge_storm/storm_wiki/modules/retriever.py +++ b/knowledge_storm/storm_wiki/modules/retriever.py @@ -149,8 +149,7 @@ "WordPress.com", "Worldometer", "YouTube", - "ZDNet", -} + "ZDNet"} DEPRECATED = { "Al_Mayadeen", "ANNA_News", @@ -198,7 +197,7 @@ "VDARE", "Voltaire_Network", "WorldNetDaily", - "Zero_Hedge", + "Zero_Hedge" } BLACKLISTED = { "Advameg", @@ -219,7 +218,7 @@ "The_Points_Guy_(sponsored_content)", "Swarajya", "Veterans_Today", - "ZoomInfo", + "ZoomInfo" } @@ -238,20 +237,14 @@ class StormRetriever(Retriever): def __init__(self, rm: dspy.Retrieve, k=3): super().__init__(search_top_k=k) self._rm = rm - if hasattr(rm, "is_valid_source"): + if hasattr(rm, 'is_valid_source'): rm.is_valid_source = is_valid_wikipedia_source - def retrieve( - self, query: Union[str, List[str]], exclude_urls: List[str] = [] - ) -> List[Information]: - retrieved_data_list = self._rm( - query_or_queries=query, exclude_urls=exclude_urls - ) + def retrieve(self, query: Union[str, List[str]], exclude_urls: List[str] = []) -> List[Information]: + retrieved_data_list = self._rm(query_or_queries=query, exclude_urls=exclude_urls) for data in retrieved_data_list: - for i in range(len(data["snippets"])): + for i in range(len(data['snippets'])): # STORM generate the article with citations. We do not consider multi-hop citations. # Remove citations in the source to avoid confusion. - data["snippets"][i] = ArticleTextProcessing.remove_citations( - data["snippets"][i] - ) + data['snippets'][i] = ArticleTextProcessing.remove_citations(data['snippets'][i]) return [StormInformation.from_dict(data) for data in retrieved_data_list] diff --git a/knowledge_storm/storm_wiki/modules/storm_dataclass.py b/knowledge_storm/storm_wiki/modules/storm_dataclass.py index 43826ecc..4f54ec46 100644 --- a/knowledge_storm/storm_wiki/modules/storm_dataclass.py +++ b/knowledge_storm/storm_wiki/modules/storm_dataclass.py @@ -51,29 +51,22 @@ def from_dict(cls, info_dict): Returns: StormInformation: An instance of StormInformation. 
""" - return cls( - info_dict["url"], - info_dict["description"], - info_dict["snippets"], - info_dict["title"], - ) + return cls(info_dict['url'], info_dict['description'], info_dict['snippets'], info_dict['title']) def to_dict(self): - return { - "url": self.uuid, - "description": self.description, - "snippets": self.snippets, - "title": self.title, - } + return {"url": self.uuid, + "description": self.description, + "snippets": self.snippets, + "title": self.title} class DialogueTurn: def __init__( - self, - agent_utterance: str = None, - user_utterance: str = None, - search_queries: Optional[List[str]] = None, - search_results: Optional[List[Union[StormInformation, Dict]]] = None, + self, + agent_utterance: str = None, + user_utterance: str = None, + search_queries: Optional[List[str]] = None, + search_results: Optional[List[Union[StormInformation, Dict]]] = None ): self.agent_utterance = agent_utterance self.user_utterance = user_utterance @@ -83,9 +76,7 @@ def __init__( if self.search_results: for idx in range(len(self.search_results)): if type(self.search_results[idx]) == dict: - self.search_results[idx] = StormInformation.from_dict( - self.search_results[idx] - ) + self.search_results[idx] = StormInformation.from_dict(self.search_results[idx]) def log(self): """ @@ -94,10 +85,10 @@ def log(self): return OrderedDict( { - "agent_utterance": self.agent_utterance, - "user_utterance": self.user_utterance, - "search_queries": self.search_queries, - "search_results": [data.to_dict() for data in self.search_results], + 'agent_utterance': self.agent_utterance, + 'user_utterance': self.user_utterance, + 'search_queries': self.search_queries, + 'search_results': [data.to_dict() for data in self.search_results], } ) @@ -107,7 +98,7 @@ class StormInformationTable(InformationTable): The InformationTable class serves as data class to store the information collected during KnowledgeCuration stage. - Create subclass to incorporate more information as needed. For example, + Create subclass to incorporate more information as needed. For example, in STORM paper https://arxiv.org/pdf/2402.14207.pdf, additional information would be perspective guided dialogue history. 
""" @@ -115,17 +106,13 @@ class StormInformationTable(InformationTable): def __init__(self, conversations=List[Tuple[str, List[DialogueTurn]]]): super().__init__() self.conversations = conversations - self.url_to_info: Dict[str, StormInformation] = ( - StormInformationTable.construct_url_to_info(self.conversations) - ) + self.url_to_info: Dict[str, StormInformation] = StormInformationTable.construct_url_to_info(self.conversations) @staticmethod - def construct_url_to_info( - conversations: List[Tuple[str, List[DialogueTurn]]] - ) -> Dict[str, StormInformation]: + def construct_url_to_info(conversations: List[Tuple[str, List[DialogueTurn]]]) -> Dict[str, StormInformation]: url_to_info = {} - for persona, conv in conversations: + for (persona, conv) in conversations: for turn in conv: for storm_info in turn.search_results: if storm_info.url in url_to_info: @@ -137,13 +124,14 @@ def construct_url_to_info( return url_to_info @staticmethod - def construct_log_dict( - conversations: List[Tuple[str, List[DialogueTurn]]] - ) -> List[Dict[str, Union[str, Any]]]: + def construct_log_dict(conversations: List[Tuple[str, List[DialogueTurn]]]) -> List[Dict[str, Union[str, Any]]]: conversation_log = [] - for persona, conv in conversations: + for (persona, conv) in conversations: conversation_log.append( - {"perspective": persona, "dlg_turns": [turn.log() for turn in conv]} + { + 'perspective': persona, + 'dlg_turns': [turn.log() for turn in conv] + } ) return conversation_log @@ -158,26 +146,22 @@ def from_conversation_log_file(cls, path): conversation_log_data = FileIOHelper.load_json(path) conversations = [] for item in conversation_log_data: - dialogue_turns = [DialogueTurn(**turn) for turn in item["dlg_turns"]] - persona = item["perspective"] + dialogue_turns = [DialogueTurn(**turn) for turn in item['dlg_turns']] + persona = item['perspective'] conversations.append((persona, dialogue_turns)) return cls(conversations) def prepare_table_for_retrieval(self): - self.encoder = SentenceTransformer("paraphrase-MiniLM-L6-v2") + self.encoder = SentenceTransformer('paraphrase-MiniLM-L6-v2') self.collected_urls = [] self.collected_snippets = [] for url, information in self.url_to_info.items(): for snippet in information.snippets: self.collected_urls.append(url) self.collected_snippets.append(snippet) - self.encoded_snippets = self.encoder.encode( - self.collected_snippets, show_progress_bar=False - ) + self.encoded_snippets = self.encoder.encode(self.collected_snippets, show_progress_bar=False) - def retrieve_information( - self, queries: Union[List[str], str], search_top_k - ) -> List[StormInformation]: + def retrieve_information(self, queries: Union[List[str], str], search_top_k) -> List[StormInformation]: selected_urls = [] selected_snippets = [] if type(queries) is str: @@ -207,13 +191,14 @@ def retrieve_information( class StormArticle(Article): def __init__(self, topic_name): super().__init__(topic_name=topic_name) - self.reference = {"url_to_unified_index": {}, "url_to_info": {}} + self.reference = { + "url_to_unified_index": {}, + "url_to_info": {} + } - def find_section( - self, node: ArticleSectionNode, name: str - ) -> Optional[ArticleSectionNode]: + def find_section(self, node: ArticleSectionNode, name: str) -> Optional[ArticleSectionNode]: """ - Return the node of the section given the section name. + Return the node of the section given the section name. Args: node: the node as the root to find. 
@@ -230,18 +215,17 @@ def find_section( return result return None - def _merge_new_info_to_references( - self, new_info_list: List[StormInformation], index_to_keep=None - ) -> Dict[int, int]: + def _merge_new_info_to_references(self, new_info_list: List[StormInformation], index_to_keep=None) -> Dict[ + int, int]: """ Merges new storm information into existing references and updates the citation index mapping. Args: - new_info_list (List[StormInformation]): A list of dictionaries representing new storm information. + new_info_list (List[StormInformation]): A list of dictionaries representing new storm information. index_to_keep (List[int]): A list of index of the new_info_list to keep. If none, keep all. Returns: - Dict[int, int]: A dictionary mapping the index of each storm information piece in the input list + Dict[int, int]: A dictionary mapping the index of each storm information piece in the input list to its unified citation index in the references. """ citation_idx_mapping = {} @@ -250,32 +234,20 @@ def _merge_new_info_to_references( continue url = storm_info.url if url not in self.reference["url_to_unified_index"]: - self.reference["url_to_unified_index"][url] = ( - len(self.reference["url_to_unified_index"]) + 1 - ) # The citation index starts from 1. + self.reference["url_to_unified_index"][url] = len( + self.reference["url_to_unified_index"]) + 1 # The citation index starts from 1. self.reference["url_to_info"][url] = storm_info else: existing_snippets = self.reference["url_to_info"][url].snippets existing_snippets.extend(storm_info.snippets) - self.reference["url_to_info"][url].snippets = list( - set(existing_snippets) - ) + self.reference["url_to_info"][url].snippets = list(set(existing_snippets)) citation_idx_mapping[idx + 1] = self.reference["url_to_unified_index"][ - url - ] # The citation index starts from 1. + url] # The citation index starts from 1. 
return citation_idx_mapping - def insert_or_create_section( - self, - article_dict: Dict[str, Dict], - parent_section_name: str = None, - trim_children=False, - ): - parent_node = ( - self.root - if parent_section_name is None - else self.find_section(self.root, parent_section_name) - ) + def insert_or_create_section(self, article_dict: Dict[str, Dict], parent_section_name: str = None, + trim_children=False): + parent_node = self.root if parent_section_name is None else self.find_section(self.root, parent_section_name) if trim_children: section_names = set(article_dict.keys()) @@ -286,83 +258,56 @@ def insert_or_create_section( for section_name, content_dict in article_dict.items(): current_section_node = self.find_section(parent_node, section_name) if current_section_node is None: - current_section_node = ArticleSectionNode( - section_name=section_name, content=content_dict["content"].strip() - ) - insert_to_front = ( - parent_node.section_name == self.root.section_name - and current_section_node.section_name == "summary" - ) - parent_node.add_child( - current_section_node, insert_to_front=insert_to_front - ) + current_section_node = ArticleSectionNode(section_name=section_name, + content=content_dict["content"].strip()) + insert_to_front = parent_node.section_name == self.root.section_name and current_section_node.section_name == "summary" + parent_node.add_child(current_section_node, insert_to_front=insert_to_front) else: current_section_node.content = content_dict["content"].strip() - self.insert_or_create_section( - article_dict=content_dict["subsections"], - parent_section_name=section_name, - trim_children=True, - ) + self.insert_or_create_section(article_dict=content_dict["subsections"], parent_section_name=section_name, + trim_children=True) - def update_section( - self, - current_section_content: str, - current_section_info_list: List[StormInformation], - parent_section_name: Optional[str] = None, - ) -> Optional[ArticleSectionNode]: + def update_section(self, + current_section_content: str, + current_section_info_list: List[StormInformation], + parent_section_name: Optional[str] = None) -> Optional[ArticleSectionNode]: """ - Add new section to the article. + Add new section to the article. Args: current_section_name: new section heading name in string format. parent_section_name: under which parent section to add the new one. Default to root. - current_section_content: optional section content. - + current_section_content: optional section content. + Returns: the ArticleSectionNode for current section if successfully created / updated. Otherwise none. 
""" if current_section_info_list is not None: - references = set( - [int(x) for x in re.findall(r"\[(\d+)\]", current_section_content)] - ) + references = set([int(x) for x in re.findall(r'\[(\d+)\]', current_section_content)]) # for any reference number greater than max number of references, delete the reference if len(references) > 0: max_ref_num = max(references) if max_ref_num > len(current_section_info_list): for i in range(len(current_section_info_list), max_ref_num + 1): - current_section_content = current_section_content.replace( - f"[{i}]", "" - ) + current_section_content = current_section_content.replace(f'[{i}]', '') if i in references: references.remove(i) # for any reference that is not used, trim it from current_section_info_list index_to_keep = [i - 1 for i in references] - citation_mapping = self._merge_new_info_to_references( - current_section_info_list, index_to_keep - ) - current_section_content = ArticleTextProcessing.update_citation_index( - current_section_content, citation_mapping - ) + citation_mapping = self._merge_new_info_to_references(current_section_info_list, index_to_keep) + current_section_content = ArticleTextProcessing.update_citation_index(current_section_content, + citation_mapping) if parent_section_name is None: parent_section_name = self.root.section_name - article_dict = ArticleTextProcessing.parse_article_into_dict( - current_section_content - ) - self.insert_or_create_section( - article_dict=article_dict, - parent_section_name=parent_section_name, - trim_children=False, - ) + article_dict = ArticleTextProcessing.parse_article_into_dict(current_section_content) + self.insert_or_create_section(article_dict=article_dict, parent_section_name=parent_section_name, + trim_children=False) - def get_outline_as_list( - self, - root_section_name: Optional[str] = None, - add_hashtags: bool = False, - include_root: bool = True, - ) -> List[str]: + def get_outline_as_list(self, root_section_name: Optional[str] = None, add_hashtags: bool = False, + include_root: bool = True) -> List[str]: """ Get outline of the article as a list. @@ -375,7 +320,7 @@ def get_outline_as_list( ###section1.2 ##section2 article.get_outline_as_list("section1") returns [section1, section1.1, section1.2, section2] - + Returns: list of section and subsection names. """ @@ -389,14 +334,8 @@ def get_outline_as_list( result = [] def preorder_traverse(node, level): - prefix = ( - "#" * level if add_hashtags else "" - ) # Adjust level if excluding root - result.append( - f"{prefix} {node.section_name}".strip() - if add_hashtags - else node.section_name - ) + prefix = "#" * level if add_hashtags else "" # Adjust level if excluding root + result.append(f"{prefix} {node.section_name}".strip() if add_hashtags else node.section_name) for child in node.children: preorder_traverse(child, level + 1) @@ -411,7 +350,7 @@ def preorder_traverse(node, level): def to_string(self) -> str: """ Get outline of the article as a list. - + Returns: list of section and subsection names. 
""" @@ -437,9 +376,7 @@ def reorder_reference_index(self): def pre_order_find_index(node): if node is not None: if node.content is not None and node.content: - ref_indices.extend( - ArticleTextProcessing.parse_citation_indices(node.content) - ) + ref_indices.extend(ArticleTextProcessing.parse_citation_indices(node.content)) for child in node.children: pre_order_find_index(child) @@ -454,9 +391,7 @@ def pre_order_find_index(node): def pre_order_update_index(node): if node is not None: if node.content is not None and node.content: - node.content = ArticleTextProcessing.update_citation_index( - node.content, ref_index_mapping - ) + node.content = ArticleTextProcessing.update_citation_index(node.content, ref_index_mapping) for child in node.children: pre_order_update_index(child) @@ -507,18 +442,18 @@ def from_outline_str(cls, topic: str, outline_str: str): instance = cls(topic) if lines: - a = lines[0].startswith("#") and lines[0].replace("#", "").strip().lower() + a = lines[0].startswith('#') and lines[0].replace('#', '').strip().lower() b = topic.lower().replace("_", " ") - adjust_level = lines[0].startswith("#") and lines[0].replace( - "#", "" - ).strip().lower() == topic.lower().replace("_", " ") + adjust_level = lines[0].startswith('#') and lines[0].replace('#', + '').strip().lower() == topic.lower().replace( + "_", " ") if adjust_level: lines = lines[1:] node_stack = [(0, instance.root)] # Stack to keep track of (level, node) for line in lines: - level = line.count("#") - adjust_level - section_name = line.replace("#", "").strip() + level = line.count('#') - adjust_level + section_name = line.replace('#', '').strip() if section_name == topic: continue @@ -552,9 +487,7 @@ def from_string(cls, topic_name: str, article_text: str, references: dict): article = cls(topic_name=topic_name) article.insert_or_create_section(article_dict=article_dict) for url in list(references["url_to_info"]): - references["url_to_info"][url] = StormInformation.from_dict( - references["url_to_info"][url] - ) + references["url_to_info"][url] = StormInformation.from_dict(references["url_to_info"][url]) article.reference = references return article diff --git a/knowledge_storm/utils.py b/knowledge_storm/utils.py index d07d067c..5cf6f457 100644 --- a/knowledge_storm/utils.py +++ b/knowledge_storm/utils.py @@ -17,7 +17,7 @@ def load_api_key(toml_file_path): try: - with open(toml_file_path, "r") as file: + with open(toml_file_path, 'r') as file: data = toml.load(file) except FileNotFoundError: print(f"File not found: {toml_file_path}", file=sys.stderr) @@ -53,19 +53,19 @@ def limit_word_count_preserve_newline(input_string, max_word_count): """ word_count = 0 - limited_string = "" + limited_string = '' - for word in input_string.split("\n"): + for word in input_string.split('\n'): line_words = word.split() for lw in line_words: if word_count < max_word_count: - limited_string += lw + " " + limited_string += lw + ' ' word_count += 1 else: break if word_count >= max_word_count: break - limited_string = limited_string.strip() + "\n" + limited_string = limited_string.strip() + '\n' return limited_string.strip() @@ -83,7 +83,7 @@ def remove_citations(s): str: The string with all citation patterns removed. """ - return re.sub(r"\[\d+(?:,\s*\d+)*\]", "", s) + return re.sub(r'\[\d+(?:,\s*\d+)*\]', '', s) @staticmethod def parse_citation_indices(s): @@ -96,7 +96,7 @@ def parse_citation_indices(s): Returns: List[int]: A list of unique citation indexes extracted from the content, in the order they appear. 
""" - matches = re.findall(r"\[\d+\]", s) + matches = re.findall(r'\[\d+\]', s) return [int(index[1:-1]) for index in matches] @staticmethod @@ -117,21 +117,19 @@ def remove_uncompleted_sentences_with_citations(text): # Convert citations like [1, 2, 3] to [1][2][3]. def replace_with_individual_brackets(match): - numbers = match.group(1).split(", ") - return " ".join(f"[{n}]" for n in numbers) + numbers = match.group(1).split(', ') + return ' '.join(f'[{n}]' for n in numbers) # Deduplicate and sort individual groups of citations. def deduplicate_group(match): citations = match.group(0) - unique_citations = list(set(re.findall(r"\[\d+\]", citations))) - sorted_citations = sorted( - unique_citations, key=lambda x: int(x.strip("[]")) - ) + unique_citations = list(set(re.findall(r'\[\d+\]', citations))) + sorted_citations = sorted(unique_citations, key=lambda x: int(x.strip('[]'))) # Return the sorted unique citations as a string - return "".join(sorted_citations) + return ''.join(sorted_citations) - text = re.sub(r"\[([0-9, ]+)\]", replace_with_individual_brackets, text) - text = re.sub(r"(\[\d+\])+", deduplicate_group, text) + text = re.sub(r'\[([0-9, ]+)\]', replace_with_individual_brackets, text) + text = re.sub(r'(\[\d+\])+', deduplicate_group, text) # Deprecated: Remove sentence without proper ending punctuation and citations. # Split the text into sentences (including citations). @@ -152,38 +150,29 @@ def deduplicate_group(match): # combined_sentences += ' '.join(trailing_citations) # Regex pattern to match sentence endings, including optional citation markers. - eos_pattern = r"([.!?])\s*(\[\d+\])?\s*" + eos_pattern = r'([.!?])\s*(\[\d+\])?\s*' matches = list(re.finditer(eos_pattern, text)) if matches: last_match = matches[-1] - text = text[: last_match.end()].strip() + text = text[:last_match.end()].strip() return text @staticmethod def clean_up_citation(conv): for turn in conv.dlg_history: - turn.agent_utterance = turn.agent_utterance[ - : turn.agent_utterance.find("References:") - ] - turn.agent_utterance = turn.agent_utterance[ - : turn.agent_utterance.find("Sources:") - ] - turn.agent_utterance = turn.agent_utterance.replace("Answer:", "").strip() + turn.agent_utterance = turn.agent_utterance[:turn.agent_utterance.find('References:')] + turn.agent_utterance = turn.agent_utterance[:turn.agent_utterance.find('Sources:')] + turn.agent_utterance = turn.agent_utterance.replace('Answer:', '').strip() try: - max_ref_num = max( - [int(x) for x in re.findall(r"\[(\d+)\]", turn.agent_utterance)] - ) + max_ref_num = max([int(x) for x in re.findall(r'\[(\d+)\]', turn.agent_utterance)]) except Exception as e: max_ref_num = 0 if max_ref_num > len(turn.search_results): for i in range(len(turn.search_results), max_ref_num + 1): - turn.agent_utterance = turn.agent_utterance.replace(f"[{i}]", "") - turn.agent_utterance = ( - ArticleTextProcessing.remove_uncompleted_sentences_with_citations( - turn.agent_utterance - ) - ) + turn.agent_utterance = turn.agent_utterance.replace(f'[{i}]', '') + turn.agent_utterance = ArticleTextProcessing.remove_uncompleted_sentences_with_citations( + turn.agent_utterance) return conv @@ -192,46 +181,36 @@ def clean_up_outline(outline, topic=""): output_lines = [] current_level = 0 # To track the current section level - for line in outline.split("\n"): + for line in outline.split('\n'): stripped_line = line.strip() if topic != "" and f"# {topic.lower()}" in stripped_line.lower(): output_lines = [] # Check if the line is a section header - if 
stripped_line.startswith("#"): - current_level = stripped_line.count("#") + if stripped_line.startswith('#'): + current_level = stripped_line.count('#') output_lines.append(stripped_line) # Check if the line is a bullet point - elif stripped_line.startswith("-"): - subsection_header = ( - "#" * (current_level + 1) + " " + stripped_line[1:].strip() - ) + elif stripped_line.startswith('-'): + subsection_header = '#' * (current_level + 1) + ' ' + stripped_line[1:].strip() output_lines.append(subsection_header) - outline = "\n".join(output_lines) + outline = '\n'.join(output_lines) # Remove references. - outline = re.sub(r"#[#]? See also.*?(?=##|$)", "", outline, flags=re.DOTALL) - outline = re.sub(r"#[#]? See Also.*?(?=##|$)", "", outline, flags=re.DOTALL) - outline = re.sub(r"#[#]? Notes.*?(?=##|$)", "", outline, flags=re.DOTALL) - outline = re.sub(r"#[#]? References.*?(?=##|$)", "", outline, flags=re.DOTALL) - outline = re.sub( - r"#[#]? External links.*?(?=##|$)", "", outline, flags=re.DOTALL - ) - outline = re.sub( - r"#[#]? External Links.*?(?=##|$)", "", outline, flags=re.DOTALL - ) - outline = re.sub(r"#[#]? Bibliography.*?(?=##|$)", "", outline, flags=re.DOTALL) - outline = re.sub( - r"#[#]? Further reading*?(?=##|$)", "", outline, flags=re.DOTALL - ) - outline = re.sub( - r"#[#]? Further Reading*?(?=##|$)", "", outline, flags=re.DOTALL - ) - outline = re.sub(r"#[#]? Summary.*?(?=##|$)", "", outline, flags=re.DOTALL) - outline = re.sub(r"#[#]? Appendices.*?(?=##|$)", "", outline, flags=re.DOTALL) - outline = re.sub(r"#[#]? Appendix.*?(?=##|$)", "", outline, flags=re.DOTALL) + outline = re.sub(r"#[#]? See also.*?(?=##|$)", '', outline, flags=re.DOTALL) + outline = re.sub(r"#[#]? See Also.*?(?=##|$)", '', outline, flags=re.DOTALL) + outline = re.sub(r"#[#]? Notes.*?(?=##|$)", '', outline, flags=re.DOTALL) + outline = re.sub(r"#[#]? References.*?(?=##|$)", '', outline, flags=re.DOTALL) + outline = re.sub(r"#[#]? External links.*?(?=##|$)", '', outline, flags=re.DOTALL) + outline = re.sub(r"#[#]? External Links.*?(?=##|$)", '', outline, flags=re.DOTALL) + outline = re.sub(r"#[#]? Bibliography.*?(?=##|$)", '', outline, flags=re.DOTALL) + outline = re.sub(r"#[#]? Further reading*?(?=##|$)", '', outline, flags=re.DOTALL) + outline = re.sub(r"#[#]? Further Reading*?(?=##|$)", '', outline, flags=re.DOTALL) + outline = re.sub(r"#[#]? Summary.*?(?=##|$)", '', outline, flags=re.DOTALL) + outline = re.sub(r"#[#]? Appendices.*?(?=##|$)", '', outline, flags=re.DOTALL) + outline = re.sub(r"#[#]? Appendix.*?(?=##|$)", '', outline, flags=re.DOTALL) return outline @@ -242,40 +221,34 @@ def clean_up_section(text): 2. Deduplicate individual groups of citations. 3. 
Remove unnecessary summary.""" - paragraphs = text.split("\n") + paragraphs = text.split('\n') output_paragraphs = [] summary_sec_flag = False for p in paragraphs: p = p.strip() if len(p) == 0: continue - if not p.startswith("#"): + if not p.startswith('#'): p = ArticleTextProcessing.remove_uncompleted_sentences_with_citations(p) if summary_sec_flag: - if p.startswith("#"): + if p.startswith('#'): summary_sec_flag = False else: continue - if ( - p.startswith("Overall") - or p.startswith("In summary") - or p.startswith("In conclusion") - ): + if p.startswith('Overall') or p.startswith('In summary') or p.startswith('In conclusion'): continue - if "# Summary" in p or "# Conclusion" in p: + if "# Summary" in p or '# Conclusion' in p: summary_sec_flag = True continue output_paragraphs.append(p) - return "\n\n".join(output_paragraphs) # Join with '\n\n' for markdown format. + return '\n\n'.join(output_paragraphs) # Join with '\n\n' for markdown format. @staticmethod def update_citation_index(s, citation_map): """Update citation index in the string based on the citation map.""" for original_citation in citation_map: - s = s.replace( - f"[{original_citation}]", f"__PLACEHOLDER_{original_citation}__" - ) + s = s.replace(f"[{original_citation}]", f"__PLACEHOLDER_{original_citation}__") for original_citation, unify_citation in citation_map.items(): s = s.replace(f"__PLACEHOLDER_{original_citation}__", f"[{unify_citation}]") @@ -302,34 +275,34 @@ def parse_article_into_dict(input_string): A dictionary representing contains the section title as the key, and another dictionary as the value, which includes the 'content' and 'subsections' keys as described above. """ - lines = input_string.split("\n") + lines = input_string.split('\n') lines = [line for line in lines if line.strip()] - root = {"content": "", "subsections": {}} + root = {'content': '', 'subsections': {}} current_path = [(root, -1)] # (current_dict, level) for line in lines: - if line.startswith("#"): - level = line.count("#") - title = line.strip("# ").strip() - new_section = {"content": "", "subsections": {}} + if line.startswith('#'): + level = line.count('#') + title = line.strip('# ').strip() + new_section = {'content': '', 'subsections': {}} # Pop from stack until find the parent level while current_path and current_path[-1][1] >= level: current_path.pop() # Append new section to the nearest upper level's subsections - current_path[-1][0]["subsections"][title] = new_section + current_path[-1][0]['subsections'][title] = new_section current_path.append((new_section, level)) else: - current_path[-1][0]["content"] += line + "\n" + current_path[-1][0]['content'] += line + '\n' - return root["subsections"] + return root['subsections'] class FileIOHelper: @staticmethod def dump_json(obj, file_name, encoding="utf-8"): - with open(file_name, "w", encoding=encoding) as fw: + with open(file_name, 'w', encoding=encoding) as fw: json.dump(obj, fw, default=FileIOHelper.handle_non_serializable) @staticmethod @@ -338,27 +311,27 @@ def handle_non_serializable(obj): @staticmethod def load_json(file_name, encoding="utf-8"): - with open(file_name, "r", encoding=encoding) as fr: + with open(file_name, 'r', encoding=encoding) as fr: return json.load(fr) @staticmethod def write_str(s, path): - with open(path, "w") as f: + with open(path, 'w') as f: f.write(s) @staticmethod def load_str(path): - with open(path, "r") as f: - return "\n".join(f.readlines()) + with open(path, 'r') as f: + return '\n'.join(f.readlines()) @staticmethod def dump_pickle(obj, path): 
- with open(path, "wb") as f: + with open(path, 'wb') as f: pickle.dump(obj, f) @staticmethod def load_pickle(path): - with open(path, "rb") as f: + with open(path, 'rb') as f: return pickle.load(f) @@ -368,12 +341,7 @@ class WebPageHelper: Acknowledgement: Part of the code is adapted from https://github.com/stanford-oval/WikiChat project. """ - def __init__( - self, - min_char_count: int = 150, - snippet_chunk_size: int = 1000, - max_thread_num: int = 10, - ): + def __init__(self, min_char_count: int = 150, snippet_chunk_size: int = 1000, max_thread_num: int = 10): """ Args: min_char_count: Minimum character count for the article to be considered valid. @@ -414,9 +382,7 @@ def download_webpage(self, url: str): return None def urls_to_articles(self, urls: List[str]) -> Dict: - with concurrent.futures.ThreadPoolExecutor( - max_workers=self.max_thread_num - ) as executor: + with concurrent.futures.ThreadPoolExecutor(max_workers=self.max_thread_num) as executor: htmls = list(executor.map(self.download_webpage, urls)) articles = {} From 89be9743e94af3ed7c914d0f87532194d2abc481 Mon Sep 17 00:00:00 2001 From: zenith110 Date: Sun, 28 Jul 2024 11:29:57 -0400 Subject: [PATCH 4/7] Manually updated all examples except serper to follow the same style as main branch --- .../process_kaggle_arxiv_abstract_dataset.py | 25 +- examples/run_storm_wiki_claude.py | 147 ++++-------- examples/run_storm_wiki_deepseek.py | 173 ++++---------- examples/run_storm_wiki_gpt.py | 155 ++++-------- examples/run_storm_wiki_gpt_with_VectorRM.py | 226 ++++++------------ examples/run_storm_wiki_ollama.py | 211 ++++++---------- examples/run_storm_wiki_serper.py | 2 +- 7 files changed, 297 insertions(+), 642 deletions(-) diff --git a/examples/helper/process_kaggle_arxiv_abstract_dataset.py b/examples/helper/process_kaggle_arxiv_abstract_dataset.py index 30583c4c..1a07d062 100644 --- a/examples/helper/process_kaggle_arxiv_abstract_dataset.py +++ b/examples/helper/process_kaggle_arxiv_abstract_dataset.py @@ -8,28 +8,21 @@ if __name__ == "__main__": parser = ArgumentParser() - parser.add_argument( - "--input-path", type=str, help="Path to arxiv_data_210930-054931.csv." - ) - parser.add_argument( - "--output-path", - type=str, - help="Path to store the csv file that is compatible with VectorRM.", - ) + parser.add_argument("--input-path", type=str, help="Path to arxiv_data_210930-054931.csv.") + parser.add_argument("--output-path", type=str, + help="Path to store the csv file that is compatible with VectorRM.") args = parser.parse_args() df = pd.read_csv(args.input_path) - print(f"The original dataset has {len(df)} samples.") + print(f'The original dataset has {len(df)} samples.') # Downsample the dataset. - df = df[df["terms"] == "['cs.CV']"] + df = df[df['terms'] == "['cs.CV']"] # Reformat the dataset to match the VectorRM input format. df.rename(columns={"abstracts": "content", "titles": "title"}, inplace=True) - df["url"] = [ - "uid_" + str(idx) for idx in range(len(df)) - ] # Ensure the url is unique. - df["description"] = "" + df['url'] = ['uid_' + str(idx) for idx in range(len(df))] # Ensure the url is unique. 
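    # For context (illustrative note, not part of the patch): VectorRM consumes
    # a CSV with exactly these four columns; a hypothetical two-row file:
    #
    #   content,title,url,description
    #   "We survey convolutional networks ...",A CNN survey,uid_0,
    #   "We survey vision transformers ...",A ViT survey,uid_1,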
+ df['description'] = '' - print(f"The downsampled dataset has {len(df)} samples.") - df.to_csv(args.output_path, index=False) + print(f'The downsampled dataset has {len(df)} samples.') + df.to_csv(args.output_path, index=False) \ No newline at end of file diff --git a/examples/run_storm_wiki_claude.py b/examples/run_storm_wiki_claude.py index 3d6847cb..64031f6a 100644 --- a/examples/run_storm_wiki_claude.py +++ b/examples/run_storm_wiki_claude.py @@ -19,23 +19,19 @@ import os from argparse import ArgumentParser -from knowledge_storm import ( - STORMWikiRunnerArguments, - STORMWikiRunner, - STORMWikiLMConfigs, -) +from knowledge_storm import STORMWikiRunnerArguments, STORMWikiRunner, STORMWikiLMConfigs from knowledge_storm.lm import ClaudeModel from knowledge_storm.rm import YouRM, BingSearch from knowledge_storm.utils import load_api_key def main(args): - load_api_key(toml_file_path="secrets.toml") + load_api_key(toml_file_path='secrets.toml') lm_configs = STORMWikiLMConfigs() claude_kwargs = { - "api_key": os.getenv("ANTHROPIC_API_KEY"), - "temperature": 1.0, - "top_p": 0.9, + 'api_key': os.getenv("ANTHROPIC_API_KEY"), + 'temperature': 1.0, + 'top_p': 0.9 } # STORM is a LM system so different components can be powered by different models. @@ -43,21 +39,11 @@ def main(args): # which is used to split queries, synthesize answers in the conversation. We recommend using stronger models # for outline_gen_lm which is responsible for organizing the collected information, and article_gen_lm # which is responsible for generating sections with citations. - conv_simulator_lm = ClaudeModel( - model="claude-3-haiku-20240307", max_tokens=500, **claude_kwargs - ) - question_asker_lm = ClaudeModel( - model="claude-3-sonnet-20240229", max_tokens=500, **claude_kwargs - ) - outline_gen_lm = ClaudeModel( - model="claude-3-opus-20240229", max_tokens=400, **claude_kwargs - ) - article_gen_lm = ClaudeModel( - model="claude-3-opus-20240229", max_tokens=700, **claude_kwargs - ) - article_polish_lm = ClaudeModel( - model="claude-3-opus-20240229", max_tokens=4000, **claude_kwargs - ) + conv_simulator_lm = ClaudeModel(model='claude-3-haiku-20240307', max_tokens=500, **claude_kwargs) + question_asker_lm = ClaudeModel(model='claude-3-sonnet-20240229', max_tokens=500, **claude_kwargs) + outline_gen_lm = ClaudeModel(model='claude-3-opus-20240229', max_tokens=400, **claude_kwargs) + article_gen_lm = ClaudeModel(model='claude-3-opus-20240229', max_tokens=700, **claude_kwargs) + article_polish_lm = ClaudeModel(model='claude-3-opus-20240229', max_tokens=4000, **claude_kwargs) lm_configs.set_conv_simulator_lm(conv_simulator_lm) lm_configs.set_question_asker_lm(question_asker_lm) @@ -75,16 +61,14 @@ def main(args): # STORM is a knowledge curation system which consumes information from the retrieval module. # Currently, the information source is the Internet and we use search engine API as the retrieval module. 
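    # Sketch, not part of this patch: neither branch below handles an
    # unrecognized --retriever value (argparse allows omitting the flag), which
    # leaves `rm` unbound and fails later with a NameError. A defensive
    # variant using this script's own names would fail fast instead:
    required_key = {'bing': 'BING_SEARCH_API_KEY', 'you': 'YDC_API_KEY'}.get(args.retriever)
    if required_key is None:
        raise ValueError(f'--retriever must be one of bing/you, got {args.retriever!r}')
    if not os.getenv(required_key):
        raise RuntimeError(f'{required_key} must be set for --retriever {args.retriever}')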
- if args.retriever == "bing": - rm = BingSearch( - bing_search_api=os.getenv("BING_SEARCH_API_KEY"), k=engine_args.search_top_k - ) - elif args.retriever == "you": - rm = YouRM(ydc_api_key=os.getenv("YDC_API_KEY"), k=engine_args.search_top_k) + if args.retriever == 'bing': + rm = BingSearch(bing_search_api=os.getenv('BING_SEARCH_API_KEY'), k=engine_args.search_top_k) + elif args.retriever == 'you': + rm = YouRM(ydc_api_key=os.getenv('YDC_API_KEY'), k=engine_args.search_top_k) runner = STORMWikiRunner(engine_args, lm_configs, rm) - topic = input("Topic: ") + topic = input('Topic: ') runner.run( topic=topic, do_research=args.do_research, @@ -96,81 +80,38 @@ def main(args): runner.summary() -if __name__ == "__main__": +if __name__ == '__main__': parser = ArgumentParser() # global arguments - parser.add_argument( - "--output-dir", - type=str, - default="./results/claude", - help="Directory to store the outputs.", - ) - parser.add_argument( - "--max-thread-num", - type=int, - default=3, - help="Maximum number of threads to use. The information seeking part and the article generation" - "part can speed up by using multiple threads. Consider reducing it if keep getting " - '"Exceed rate limit" error when calling LM API.', - ) - parser.add_argument( - "--retriever", - type=str, - choices=["bing", "you"], - help="The search engine API to use for retrieving information.", - ) + parser.add_argument('--output-dir', type=str, default='./results/claude', + help='Directory to store the outputs.') + parser.add_argument('--max-thread-num', type=int, default=3, + help='Maximum number of threads to use. The information seeking part and the article generation' + 'part can speed up by using multiple threads. Consider reducing it if keep getting ' + '"Exceed rate limit" error when calling LM API.') + parser.add_argument('--retriever', type=str, choices=['bing', 'you'], + help='The search engine API to use for retrieving information.') # stage of the pipeline - parser.add_argument( - "--do-research", - action="store_true", - help="If True, simulate conversation to research the topic; otherwise, load the results.", - ) - parser.add_argument( - "--do-generate-outline", - action="store_true", - help="If True, generate an outline for the topic; otherwise, load the results.", - ) - parser.add_argument( - "--do-generate-article", - action="store_true", - help="If True, generate an article for the topic; otherwise, load the results.", - ) - parser.add_argument( - "--do-polish-article", - action="store_true", - help="If True, polish the article by adding a summarization section and (optionally) removing " - "duplicate content.", - ) + parser.add_argument('--do-research', action='store_true', + help='If True, simulate conversation to research the topic; otherwise, load the results.') + parser.add_argument('--do-generate-outline', action='store_true', + help='If True, generate an outline for the topic; otherwise, load the results.') + parser.add_argument('--do-generate-article', action='store_true', + help='If True, generate an article for the topic; otherwise, load the results.') + parser.add_argument('--do-polish-article', action='store_true', + help='If True, polish the article by adding a summarization section and (optionally) removing ' + 'duplicate content.') # hyperparameters for the pre-writing stage - parser.add_argument( - "--max-conv-turn", - type=int, - default=3, - help="Maximum number of questions in conversational question asking.", - ) - parser.add_argument( - "--max-perspective", - type=int, - default=3, 
- help="Maximum number of perspectives to consider in perspective-guided question asking.", - ) - parser.add_argument( - "--search-top-k", - type=int, - default=3, - help="Top k search results to consider for each search query.", - ) + parser.add_argument('--max-conv-turn', type=int, default=3, + help='Maximum number of questions in conversational question asking.') + parser.add_argument('--max-perspective', type=int, default=3, + help='Maximum number of perspectives to consider in perspective-guided question asking.') + parser.add_argument('--search-top-k', type=int, default=3, + help='Top k search results to consider for each search query.') # hyperparameters for the writing stage - parser.add_argument( - "--retrieve-top-k", - type=int, - default=3, - help="Top k collected references for each section title.", - ) - parser.add_argument( - "--remove-duplicate", - action="store_true", - help="If True, remove duplicate content from the article.", - ) + parser.add_argument('--retrieve-top-k', type=int, default=3, + help='Top k collected references for each section title.') + parser.add_argument('--remove-duplicate', action='store_true', + help='If True, remove duplicate content from the article.') - main(parser.parse_args()) + main(parser.parse_args()) \ No newline at end of file diff --git a/examples/run_storm_wiki_deepseek.py b/examples/run_storm_wiki_deepseek.py index d159e948..2a7b2566 100644 --- a/examples/run_storm_wiki_deepseek.py +++ b/examples/run_storm_wiki_deepseek.py @@ -23,11 +23,7 @@ import logging from argparse import ArgumentParser -from knowledge_storm import ( - STORMWikiRunnerArguments, - STORMWikiRunner, - STORMWikiLMConfigs, -) +from knowledge_storm import STORMWikiRunnerArguments, STORMWikiRunner, STORMWikiLMConfigs from knowledge_storm.lm import DeepSeekModel from knowledge_storm.rm import YouRM, BingSearch from knowledge_storm.utils import load_api_key @@ -39,10 +35,10 @@ def sanitize_topic(topic): Remove or replace characters that are not allowed in file names. """ # Replace spaces with underscores - topic = topic.replace(" ", "_") + topic = topic.replace(' ', '_') # Remove any character that isn't alphanumeric, underscore, or hyphen - topic = re.sub(r"[^a-zA-Z0-9_-]", "", topic) + topic = re.sub(r'[^a-zA-Z0-9_-]', '', topic) # Ensure the topic isn't empty after sanitization if not topic: @@ -52,35 +48,27 @@ def sanitize_topic(topic): def main(args): - load_api_key(toml_file_path="secrets.toml") + load_api_key(toml_file_path='secrets.toml') lm_configs = STORMWikiLMConfigs() # Ensure DEEPSEEK_API_KEY is set if not os.getenv("DEEPSEEK_API_KEY"): - raise ValueError( - "DEEPSEEK_API_KEY environment variable is not set. Please set it in your secrets.toml file." - ) + raise ValueError("DEEPSEEK_API_KEY environment variable is not set. 
Please set it in your secrets.toml file.") deepseek_kwargs = { - "api_key": os.getenv("DEEPSEEK_API_KEY"), - "api_base": os.getenv("DEEPSEEK_API_BASE", "https://api.deepseek.com"), - "temperature": args.temperature, - "top_p": args.top_p, + 'api_key': os.getenv("DEEPSEEK_API_KEY"), + 'api_base': os.getenv("DEEPSEEK_API_BASE", "https://api.deepseek.com"), + 'temperature': args.temperature, + 'top_p': args.top_p, } # DeepSeek offers two main models: 'deepseek-chat' for general tasks and 'deepseek-coder' for coding tasks # Users can choose the appropriate model based on their needs - conv_simulator_lm = DeepSeekModel( - model=args.model, max_tokens=500, **deepseek_kwargs - ) - question_asker_lm = DeepSeekModel( - model=args.model, max_tokens=500, **deepseek_kwargs - ) + conv_simulator_lm = DeepSeekModel(model=args.model, max_tokens=500, **deepseek_kwargs) + question_asker_lm = DeepSeekModel(model=args.model, max_tokens=500, **deepseek_kwargs) outline_gen_lm = DeepSeekModel(model=args.model, max_tokens=400, **deepseek_kwargs) article_gen_lm = DeepSeekModel(model=args.model, max_tokens=700, **deepseek_kwargs) - article_polish_lm = DeepSeekModel( - model=args.model, max_tokens=4000, **deepseek_kwargs - ) + article_polish_lm = DeepSeekModel(model=args.model, max_tokens=4000, **deepseek_kwargs) lm_configs.set_conv_simulator_lm(conv_simulator_lm) lm_configs.set_question_asker_lm(question_asker_lm) @@ -98,20 +86,16 @@ def main(args): # STORM is a knowledge curation system which consumes information from the retrieval module. # Currently, the information source is the Internet and we use search engine API as the retrieval module. - if args.retriever == "bing": - rm = BingSearch( - bing_search_api=os.getenv("BING_SEARCH_API_KEY"), k=engine_args.search_top_k - ) - elif args.retriever == "you": - rm = YouRM(ydc_api_key=os.getenv("YDC_API_KEY"), k=engine_args.search_top_k) + if args.retriever == 'bing': + rm = BingSearch(bing_search_api=os.getenv('BING_SEARCH_API_KEY'), k=engine_args.search_top_k) + elif args.retriever == 'you': + rm = YouRM(ydc_api_key=os.getenv('YDC_API_KEY'), k=engine_args.search_top_k) else: - raise ValueError( - f"Invalid retriever: {args.retriever}. Choose either 'bing' or 'you'." - ) + raise ValueError(f"Invalid retriever: {args.retriever}. Choose either 'bing' or 'you'.") runner = STORMWikiRunner(engine_args, lm_configs, rm) - topic = input("Topic: ") + topic = input('Topic: ') sanitized_topic = sanitize_topic(topic) try: @@ -130,95 +114,44 @@ def main(args): raise -if __name__ == "__main__": +if __name__ == '__main__': parser = ArgumentParser() # global arguments - parser.add_argument( - "--output-dir", - type=str, - default="./results/deepseek", - help="Directory to store the outputs.", - ) - parser.add_argument( - "--max-thread-num", - type=int, - default=3, - help="Maximum number of threads to use. The information seeking part and the article generation" - "part can speed up by using multiple threads. Consider reducing it if keep getting " - '"Exceed rate limit" error when calling LM API.', - ) - parser.add_argument( - "--retriever", - type=str, - choices=["bing", "you"], - required=True, - help="The search engine API to use for retrieving information.", - ) - parser.add_argument( - "--model", - type=str, - choices=["deepseek-chat", "deepseek-coder"], - default="deepseek-chat", - help='DeepSeek model to use. 
"deepseek-chat" for general tasks, "deepseek-coder" for coding tasks.', - ) - parser.add_argument( - "--temperature", type=float, default=1.0, help="Sampling temperature to use." - ) - parser.add_argument( - "--top_p", type=float, default=0.9, help="Top-p sampling parameter." - ) + parser.add_argument('--output-dir', type=str, default='./results/deepseek', + help='Directory to store the outputs.') + parser.add_argument('--max-thread-num', type=int, default=3, + help='Maximum number of threads to use. The information seeking part and the article generation' + 'part can speed up by using multiple threads. Consider reducing it if keep getting ' + '"Exceed rate limit" error when calling LM API.') + parser.add_argument('--retriever', type=str, choices=['bing', 'you'], required=True, + help='The search engine API to use for retrieving information.') + parser.add_argument('--model', type=str, choices=['deepseek-chat', 'deepseek-coder'], default='deepseek-chat', + help='DeepSeek model to use. "deepseek-chat" for general tasks, "deepseek-coder" for coding tasks.') + parser.add_argument('--temperature', type=float, default=1.0, + help='Sampling temperature to use.') + parser.add_argument('--top_p', type=float, default=0.9, + help='Top-p sampling parameter.') # stage of the pipeline - parser.add_argument( - "--do-research", - action="store_true", - help="If True, simulate conversation to research the topic; otherwise, load the results.", - ) - parser.add_argument( - "--do-generate-outline", - action="store_true", - help="If True, generate an outline for the topic; otherwise, load the results.", - ) - parser.add_argument( - "--do-generate-article", - action="store_true", - help="If True, generate an article for the topic; otherwise, load the results.", - ) - parser.add_argument( - "--do-polish-article", - action="store_true", - help="If True, polish the article by adding a summarization section and (optionally) removing " - "duplicate content.", - ) + parser.add_argument('--do-research', action='store_true', + help='If True, simulate conversation to research the topic; otherwise, load the results.') + parser.add_argument('--do-generate-outline', action='store_true', + help='If True, generate an outline for the topic; otherwise, load the results.') + parser.add_argument('--do-generate-article', action='store_true', + help='If True, generate an article for the topic; otherwise, load the results.') + parser.add_argument('--do-polish-article', action='store_true', + help='If True, polish the article by adding a summarization section and (optionally) removing ' + 'duplicate content.') # hyperparameters for the pre-writing stage - parser.add_argument( - "--max-conv-turn", - type=int, - default=3, - help="Maximum number of questions in conversational question asking.", - ) - parser.add_argument( - "--max-perspective", - type=int, - default=3, - help="Maximum number of perspectives to consider in perspective-guided question asking.", - ) - parser.add_argument( - "--search-top-k", - type=int, - default=3, - help="Top k search results to consider for each search query.", - ) + parser.add_argument('--max-conv-turn', type=int, default=3, + help='Maximum number of questions in conversational question asking.') + parser.add_argument('--max-perspective', type=int, default=3, + help='Maximum number of perspectives to consider in perspective-guided question asking.') + parser.add_argument('--search-top-k', type=int, default=3, + help='Top k search results to consider for each search query.') # hyperparameters for the 
writing stage - parser.add_argument( - "--retrieve-top-k", - type=int, - default=3, - help="Top k collected references for each section title.", - ) - parser.add_argument( - "--remove-duplicate", - action="store_true", - help="If True, remove duplicate content from the article.", - ) + parser.add_argument('--retrieve-top-k', type=int, default=3, + help='Top k collected references for each section title.') + parser.add_argument('--remove-duplicate', action='store_true', + help='If True, remove duplicate content from the article.') - main(parser.parse_args()) + main(parser.parse_args()) \ No newline at end of file diff --git a/examples/run_storm_wiki_gpt.py b/examples/run_storm_wiki_gpt.py index b97c1c47..5e4dda3a 100644 --- a/examples/run_storm_wiki_gpt.py +++ b/examples/run_storm_wiki_gpt.py @@ -22,54 +22,40 @@ import os from argparse import ArgumentParser -from knowledge_storm import ( - STORMWikiRunnerArguments, - STORMWikiRunner, - STORMWikiLMConfigs, -) +from knowledge_storm import STORMWikiRunnerArguments, STORMWikiRunner, STORMWikiLMConfigs from knowledge_storm.lm import OpenAIModel, AzureOpenAIModel from knowledge_storm.rm import YouRM, BingSearch from knowledge_storm.utils import load_api_key def main(args): - load_api_key(toml_file_path="secrets.toml") + load_api_key(toml_file_path='secrets.toml') lm_configs = STORMWikiLMConfigs() openai_kwargs = { - "api_key": os.getenv("OPENAI_API_KEY"), - "temperature": 1.0, - "top_p": 0.9, + 'api_key': os.getenv("OPENAI_API_KEY"), + 'temperature': 1.0, + 'top_p': 0.9, } - ModelClass = ( - OpenAIModel if os.getenv("OPENAI_API_TYPE") == "openai" else AzureOpenAIModel - ) + ModelClass = OpenAIModel if os.getenv('OPENAI_API_TYPE') == 'openai' else AzureOpenAIModel # If you are using Azure service, make sure the model name matches your own deployed model name. # The default name here is only used for demonstration and may not match your case. - gpt_35_model_name = ( - "gpt-3.5-turbo" if os.getenv("OPENAI_API_TYPE") == "openai" else "gpt-35-turbo" - ) - gpt_4_model_name = "gpt-4o" - if os.getenv("OPENAI_API_TYPE") == "azure": - openai_kwargs["api_base"] = os.getenv("AZURE_API_BASE") - openai_kwargs["api_version"] = os.getenv("AZURE_API_VERSION") + gpt_35_model_name = 'gpt-3.5-turbo' if os.getenv('OPENAI_API_TYPE') == 'openai' else 'gpt-35-turbo' + gpt_4_model_name = 'gpt-4o' + if os.getenv('OPENAI_API_TYPE') == 'azure': + openai_kwargs['api_base'] = os.getenv('AZURE_API_BASE') + openai_kwargs['api_version'] = os.getenv('AZURE_API_VERSION') # STORM is a LM system so different components can be powered by different models. # For a good balance between cost and quality, you can choose a cheaper/faster model for conv_simulator_lm # which is used to split queries, synthesize answers in the conversation. We recommend using stronger models # for outline_gen_lm which is responsible for organizing the collected information, and article_gen_lm # which is responsible for generating sections with citations. 
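    # Before any of these models can be constructed, load_api_key('secrets.toml')
    # at the top of main() must make the needed keys visible as environment
    # variables. A sketch of that file (values are placeholders; the exact key
    # set depends on the retriever you pick):
    #
    #   OPENAI_API_KEY = "sk-..."
    #   OPENAI_API_TYPE = "openai"   # or "azure" (plus AZURE_API_BASE, AZURE_API_VERSION)
    #   BING_SEARCH_API_KEY = "..."
    #   YDC_API_KEY = "..."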
- conv_simulator_lm = ModelClass( - model=gpt_35_model_name, max_tokens=500, **openai_kwargs - ) - question_asker_lm = ModelClass( - model=gpt_35_model_name, max_tokens=500, **openai_kwargs - ) + conv_simulator_lm = ModelClass(model=gpt_35_model_name, max_tokens=500, **openai_kwargs) + question_asker_lm = ModelClass(model=gpt_35_model_name, max_tokens=500, **openai_kwargs) outline_gen_lm = ModelClass(model=gpt_4_model_name, max_tokens=400, **openai_kwargs) article_gen_lm = ModelClass(model=gpt_4_model_name, max_tokens=700, **openai_kwargs) - article_polish_lm = ModelClass( - model=gpt_4_model_name, max_tokens=4000, **openai_kwargs - ) + article_polish_lm = ModelClass(model=gpt_4_model_name, max_tokens=4000, **openai_kwargs) lm_configs.set_conv_simulator_lm(conv_simulator_lm) lm_configs.set_question_asker_lm(question_asker_lm) @@ -87,16 +73,14 @@ def main(args): # STORM is a knowledge curation system which consumes information from the retrieval module. # Currently, the information source is the Internet and we use search engine API as the retrieval module. - if args.retriever == "bing": - rm = BingSearch( - bing_search_api=os.getenv("BING_SEARCH_API_KEY"), k=engine_args.search_top_k - ) - elif args.retriever == "you": - rm = YouRM(ydc_api_key=os.getenv("YDC_API_KEY"), k=engine_args.search_top_k) + if args.retriever == 'bing': + rm = BingSearch(bing_search_api=os.getenv('BING_SEARCH_API_KEY'), k=engine_args.search_top_k) + elif args.retriever == 'you': + rm = YouRM(ydc_api_key=os.getenv('YDC_API_KEY'), k=engine_args.search_top_k) runner = STORMWikiRunner(engine_args, lm_configs, rm) - topic = input("Topic: ") + topic = input('Topic: ') runner.run( topic=topic, do_research=args.do_research, @@ -108,81 +92,38 @@ def main(args): runner.summary() -if __name__ == "__main__": +if __name__ == '__main__': parser = ArgumentParser() # global arguments - parser.add_argument( - "--output-dir", - type=str, - default="./results/gpt", - help="Directory to store the outputs.", - ) - parser.add_argument( - "--max-thread-num", - type=int, - default=3, - help="Maximum number of threads to use. The information seeking part and the article generation" - "part can speed up by using multiple threads. Consider reducing it if keep getting " - '"Exceed rate limit" error when calling LM API.', - ) - parser.add_argument( - "--retriever", - type=str, - choices=["bing", "you"], - help="The search engine API to use for retrieving information.", - ) + parser.add_argument('--output-dir', type=str, default='./results/gpt', + help='Directory to store the outputs.') + parser.add_argument('--max-thread-num', type=int, default=3, + help='Maximum number of threads to use. The information seeking part and the article generation' + 'part can speed up by using multiple threads. 
Consider reducing it if keep getting ' + '"Exceed rate limit" error when calling LM API.') + parser.add_argument('--retriever', type=str, choices=['bing', 'you'], + help='The search engine API to use for retrieving information.') # stage of the pipeline - parser.add_argument( - "--do-research", - action="store_true", - help="If True, simulate conversation to research the topic; otherwise, load the results.", - ) - parser.add_argument( - "--do-generate-outline", - action="store_true", - help="If True, generate an outline for the topic; otherwise, load the results.", - ) - parser.add_argument( - "--do-generate-article", - action="store_true", - help="If True, generate an article for the topic; otherwise, load the results.", - ) - parser.add_argument( - "--do-polish-article", - action="store_true", - help="If True, polish the article by adding a summarization section and (optionally) removing " - "duplicate content.", - ) + parser.add_argument('--do-research', action='store_true', + help='If True, simulate conversation to research the topic; otherwise, load the results.') + parser.add_argument('--do-generate-outline', action='store_true', + help='If True, generate an outline for the topic; otherwise, load the results.') + parser.add_argument('--do-generate-article', action='store_true', + help='If True, generate an article for the topic; otherwise, load the results.') + parser.add_argument('--do-polish-article', action='store_true', + help='If True, polish the article by adding a summarization section and (optionally) removing ' + 'duplicate content.') # hyperparameters for the pre-writing stage - parser.add_argument( - "--max-conv-turn", - type=int, - default=3, - help="Maximum number of questions in conversational question asking.", - ) - parser.add_argument( - "--max-perspective", - type=int, - default=3, - help="Maximum number of perspectives to consider in perspective-guided question asking.", - ) - parser.add_argument( - "--search-top-k", - type=int, - default=3, - help="Top k search results to consider for each search query.", - ) + parser.add_argument('--max-conv-turn', type=int, default=3, + help='Maximum number of questions in conversational question asking.') + parser.add_argument('--max-perspective', type=int, default=3, + help='Maximum number of perspectives to consider in perspective-guided question asking.') + parser.add_argument('--search-top-k', type=int, default=3, + help='Top k search results to consider for each search query.') # hyperparameters for the writing stage - parser.add_argument( - "--retrieve-top-k", - type=int, - default=3, - help="Top k collected references for each section title.", - ) - parser.add_argument( - "--remove-duplicate", - action="store_true", - help="If True, remove duplicate content from the article.", - ) + parser.add_argument('--retrieve-top-k', type=int, default=3, + help='Top k collected references for each section title.') + parser.add_argument('--remove-duplicate', action='store_true', + help='If True, remove duplicate content from the article.') - main(parser.parse_args()) + main(parser.parse_args()) \ No newline at end of file diff --git a/examples/run_storm_wiki_gpt_with_VectorRM.py b/examples/run_storm_wiki_gpt_with_VectorRM.py index 6eed8444..d6662d3c 100644 --- a/examples/run_storm_wiki_gpt_with_VectorRM.py +++ b/examples/run_storm_wiki_gpt_with_VectorRM.py @@ -30,11 +30,7 @@ import sys from argparse import ArgumentParser -from knowledge_storm import ( - STORMWikiRunnerArguments, - STORMWikiRunner, - STORMWikiLMConfigs, -) +from 
knowledge_storm import STORMWikiRunnerArguments, STORMWikiRunner, STORMWikiLMConfigs from knowledge_storm.rm import VectorRM from knowledge_storm.lm import OpenAIModel, AzureOpenAIModel from knowledge_storm.utils import load_api_key @@ -42,45 +38,35 @@ def main(args): # Load API key from the specified toml file path - load_api_key(toml_file_path="secrets.toml") + load_api_key(toml_file_path='secrets.toml') # Initialize the language model configurations engine_lm_configs = STORMWikiLMConfigs() openai_kwargs = { - "api_key": os.getenv("OPENAI_API_KEY"), - "temperature": 1.0, - "top_p": 0.9, + 'api_key': os.getenv("OPENAI_API_KEY"), + 'temperature': 1.0, + 'top_p': 0.9, } - ModelClass = ( - OpenAIModel if os.getenv("OPENAI_API_TYPE") == "openai" else AzureOpenAIModel - ) + ModelClass = OpenAIModel if os.getenv('OPENAI_API_TYPE') == 'openai' else AzureOpenAIModel # If you are using Azure service, make sure the model name matches your own deployed model name. # The default name here is only used for demonstration and may not match your case. - gpt_35_model_name = ( - "gpt-3.5-turbo" if os.getenv("OPENAI_API_TYPE") == "openai" else "gpt-35-turbo" - ) - gpt_4_model_name = "gpt-4o" - if os.getenv("OPENAI_API_TYPE") == "azure": - openai_kwargs["api_base"] = os.getenv("AZURE_API_BASE") - openai_kwargs["api_version"] = os.getenv("AZURE_API_VERSION") + gpt_35_model_name = 'gpt-3.5-turbo' if os.getenv('OPENAI_API_TYPE') == 'openai' else 'gpt-35-turbo' + gpt_4_model_name = 'gpt-4o' + if os.getenv('OPENAI_API_TYPE') == 'azure': + openai_kwargs['api_base'] = os.getenv('AZURE_API_BASE') + openai_kwargs['api_version'] = os.getenv('AZURE_API_VERSION') # STORM is a LM system so different components can be powered by different models. - # For a good balance between cost and quality, you can choose a cheaper/faster model for conv_simulator_lm + # For a good balance between cost and quality, you can choose a cheaper/faster model for conv_simulator_lm # which is used to split queries, synthesize answers in the conversation. We recommend using stronger models # for outline_gen_lm which is responsible for organizing the collected information, and article_gen_lm # which is responsible for generating sections with citations. 
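    # Looking ahead to the VectorRM setup below: a sketch (path and URL are
    # placeholders) of the two Qdrant modes the script selects between.
    from qdrant_client import QdrantClient
    local_db = QdrantClient(path='./vector_store')  # offline: file-backed store
    remote_db = QdrantClient(url='https://qdrant-host:6333',
                             api_key=os.getenv('QDRANT_API_KEY'))  # online: remote server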
- conv_simulator_lm = ModelClass( - model=gpt_35_model_name, max_tokens=500, **openai_kwargs - ) - question_asker_lm = ModelClass( - model=gpt_35_model_name, max_tokens=500, **openai_kwargs - ) + conv_simulator_lm = ModelClass(model=gpt_35_model_name, max_tokens=500, **openai_kwargs) + question_asker_lm = ModelClass(model=gpt_35_model_name, max_tokens=500, **openai_kwargs) outline_gen_lm = ModelClass(model=gpt_4_model_name, max_tokens=400, **openai_kwargs) article_gen_lm = ModelClass(model=gpt_4_model_name, max_tokens=700, **openai_kwargs) - article_polish_lm = ModelClass( - model=gpt_4_model_name, max_tokens=4000, **openai_kwargs - ) + article_polish_lm = ModelClass(model=gpt_4_model_name, max_tokens=4000, **openai_kwargs) engine_lm_configs.set_conv_simulator_lm(conv_simulator_lm) engine_lm_configs.set_question_asker_lm(question_asker_lm) @@ -98,36 +84,30 @@ def main(args): ) # Setup VectorRM to retrieve information from your own data - rm = VectorRM( - collection_name=args.collection_name, - device=args.device, - k=engine_args.search_top_k, - ) + rm = VectorRM(collection_name=args.collection_name, device=args.device, k=engine_args.search_top_k) # initialize the vector store, either online (store the db on Qdrant server) or offline (store the db locally): - if args.vector_db_mode == "offline": + if args.vector_db_mode == 'offline': rm.init_offline_vector_db(vector_store_path=args.offline_vector_db_dir) - elif args.vector_db_mode == "online": - rm.init_online_vector_db( - url=args.online_vector_db_url, api_key=os.getenv("QDRANT_API_KEY") - ) + elif args.vector_db_mode == 'online': + rm.init_online_vector_db(url=args.online_vector_db_url, api_key=os.getenv('QDRANT_API_KEY')) # Update the vector store with the documents in the csv file if args.update_vector_store: rm.update_vector_store( file_path=args.csv_file_path, - content_column="content", - title_column="title", - url_column="url", - desc_column="description", - batch_size=args.embed_batch_size, + content_column='content', + title_column='title', + url_column='url', + desc_column='description', + batch_size=args.embed_batch_size ) # Initialize the STORM Wiki Runner runner = STORMWikiRunner(engine_args, engine_lm_configs, rm) # run the pipeline - topic = input("Topic: ") + topic = input('Topic: ') runner.run( topic=topic, do_research=args.do_research, @@ -142,119 +122,51 @@ def main(args): if __name__ == "__main__": parser = ArgumentParser() # global arguments - parser.add_argument( - "--output-dir", - type=str, - default="./results/gpt_retrieval", - help="Directory to store the outputs.", - ) - parser.add_argument( - "--max-thread-num", - type=int, - default=3, - help="Maximum number of threads to use. The information seeking part and the article generation" - "part can speed up by using multiple threads. Consider reducing it if keep getting " - '"Exceed rate limit" error when calling LM API.', - ) + parser.add_argument('--output-dir', type=str, default='./results/gpt_retrieval', + help='Directory to store the outputs.') + parser.add_argument('--max-thread-num', type=int, default=3, + help='Maximum number of threads to use. The information seeking part and the article generation' + 'part can speed up by using multiple threads. 
Consider reducing it if keep getting ' + '"Exceed rate limit" error when calling LM API.') # provide local corpus and set up vector db - parser.add_argument( - "--collection-name", - type=str, - default="my_documents", - help="The collection name for vector store.", - ) - parser.add_argument( - "--device", - type=str, - default="mps", - help="The device used to run the retrieval model (mps, cuda, cpu, etc).", - ) - parser.add_argument( - "--vector-db-mode", - type=str, - choices=["offline", "online"], - help="The mode of the Qdrant vector store (offline or online).", - ) - parser.add_argument( - "--offline-vector-db-dir", - type=str, - default="./vector_store", - help="If use offline mode, please provide the directory to store the vector store.", - ) - parser.add_argument( - "--online-vector-db-url", - type=str, - help="If use online mode, please provide the url of the Qdrant server.", - ) - parser.add_argument( - "--update-vector-store", - action="store_true", - help="If True, update the vector store with the documents in the csv file; otherwise, " - "use the existing vector store.", - ) - parser.add_argument( - "--csv-file-path", - type=str, - help="The path of the custom document corpus in CSV format. The CSV file should include " - "content, title, url, and description columns.", - ) - parser.add_argument( - "--embed-batch-size", - type=int, - default=64, - help="Batch size for embedding the documents in the csv file.", - ) + parser.add_argument('--collection-name', type=str, default="my_documents", + help='The collection name for vector store.') + parser.add_argument('--device', type=str, default="mps", + help='The device used to run the retrieval model (mps, cuda, cpu, etc).') + parser.add_argument('--vector-db-mode', type=str, choices=['offline', 'online'], + help='The mode of the Qdrant vector store (offline or online).') + parser.add_argument('--offline-vector-db-dir', type=str, default='./vector_store', + help='If use offline mode, please provide the directory to store the vector store.') + parser.add_argument('--online-vector-db-url', type=str, + help='If use online mode, please provide the url of the Qdrant server.') + parser.add_argument('--update-vector-store', action='store_true', + help='If True, update the vector store with the documents in the csv file; otherwise, ' + 'use the existing vector store.') + parser.add_argument('--csv-file-path', type=str, + help='The path of the custom document corpus in CSV format. 
The CSV file should include ' + 'content, title, url, and description columns.') + parser.add_argument('--embed-batch-size', type=int, default=64, + help='Batch size for embedding the documents in the csv file.') # stage of the pipeline - parser.add_argument( - "--do-research", - action="store_true", - help="If True, simulate conversation to research the topic; otherwise, load the results.", - ) - parser.add_argument( - "--do-generate-outline", - action="store_true", - help="If True, generate an outline for the topic; otherwise, load the results.", - ) - parser.add_argument( - "--do-generate-article", - action="store_true", - help="If True, generate an article for the topic; otherwise, load the results.", - ) - parser.add_argument( - "--do-polish-article", - action="store_true", - help="If True, polish the article by adding a summarization section and (optionally) removing " - "duplicate content.", - ) + parser.add_argument('--do-research', action='store_true', + help='If True, simulate conversation to research the topic; otherwise, load the results.') + parser.add_argument('--do-generate-outline', action='store_true', + help='If True, generate an outline for the topic; otherwise, load the results.') + parser.add_argument('--do-generate-article', action='store_true', + help='If True, generate an article for the topic; otherwise, load the results.') + parser.add_argument('--do-polish-article', action='store_true', + help='If True, polish the article by adding a summarization section and (optionally) removing ' + 'duplicate content.') # hyperparameters for the pre-writing stage - parser.add_argument( - "--max-conv-turn", - type=int, - default=3, - help="Maximum number of questions in conversational question asking.", - ) - parser.add_argument( - "--max-perspective", - type=int, - default=3, - help="Maximum number of perspectives to consider in perspective-guided question asking.", - ) - parser.add_argument( - "--search-top-k", - type=int, - default=3, - help="Top k search results to consider for each search query.", - ) + parser.add_argument('--max-conv-turn', type=int, default=3, + help='Maximum number of questions in conversational question asking.') + parser.add_argument('--max-perspective', type=int, default=3, + help='Maximum number of perspectives to consider in perspective-guided question asking.') + parser.add_argument('--search-top-k', type=int, default=3, + help='Top k search results to consider for each search query.') # hyperparameters for the writing stage - parser.add_argument( - "--retrieve-top-k", - type=int, - default=3, - help="Top k collected references for each section title.", - ) - parser.add_argument( - "--remove-duplicate", - action="store_true", - help="If True, remove duplicate content from the article.", - ) - main(parser.parse_args()) + parser.add_argument('--retrieve-top-k', type=int, default=3, + help='Top k collected references for each section title.') + parser.add_argument('--remove-duplicate', action='store_true', + help='If True, remove duplicate content from the article.') + main(parser.parse_args()) \ No newline at end of file diff --git a/examples/run_storm_wiki_ollama.py b/examples/run_storm_wiki_ollama.py index 2e930464..76f85dc3 100644 --- a/examples/run_storm_wiki_ollama.py +++ b/examples/run_storm_wiki_ollama.py @@ -1,8 +1,8 @@ """ -STORM Wiki pipeline powered by local model hosted by Ollama server and You.com or Bing search engine. +STORM Wiki pipeline powered by Mistral-7B-Instruct-v0.2 hosted by VLLM server and You.com search engine. 
You need to set up the following environment variables to run this script: - YDC_API_KEY: You.com API key; or, BING_SEARCH_API_KEY: Bing Search API key -You also need to have a Ollama server running with the llama3 model or other. Specify `--url`, `--port` and `--model` accordingly. +You also need to have a VLLM server running with the Mistral-7B-Instruct-v0.2 model. Specify `--url` and `--port` accordingly. Output will be structured as below args.output_dir/ @@ -15,42 +15,33 @@ storm_gen_article.txt # Final article generated storm_gen_article_polished.txt # Polished final article (if args.do_polish_article is True) """ - import os -import sys from argparse import ArgumentParser from dspy import Example -sys.path.append("./src") -from lm import OllamaClient -from rm import YouRM, BingSearch -from storm_wiki.engine import ( - STORMWikiRunnerArguments, - STORMWikiRunner, - STORMWikiLMConfigs, -) -from utils import load_api_key +from knowledge_storm import STORMWikiRunnerArguments, STORMWikiRunner, STORMWikiLMConfigs +from knowledge_storm.lm import VLLMClient +from knowledge_storm.rm import YouRM, BingSearch +from knowledge_storm.utils import load_api_key def main(args): - load_api_key(toml_file_path="secrets.toml") + load_api_key(toml_file_path='secrets.toml') lm_configs = STORMWikiLMConfigs() - ollama_kwargs = { - "model": args.model, + mistral_kwargs = { + "model": "mistralai/Mistral-7B-Instruct-v0.2", "port": args.port, "url": args.url, - "stop": ( - "\n\n---", - ), # dspy uses "\n\n---" to separate examples. Open models sometimes generate this. + "stop": ('\n\n---',) # dspy uses "\n\n---" to separate examples. Open models sometimes generate this. } - conv_simulator_lm = OllamaClient(max_tokens=500, **ollama_kwargs) - question_asker_lm = OllamaClient(max_tokens=500, **ollama_kwargs) - outline_gen_lm = OllamaClient(max_tokens=400, **ollama_kwargs) - article_gen_lm = OllamaClient(max_tokens=700, **ollama_kwargs) - article_polish_lm = OllamaClient(max_tokens=4000, **ollama_kwargs) + conv_simulator_lm = VLLMClient(max_tokens=500, **mistral_kwargs) + question_asker_lm = VLLMClient(max_tokens=500, **mistral_kwargs) + outline_gen_lm = VLLMClient(max_tokens=400, **mistral_kwargs) + article_gen_lm = VLLMClient(max_tokens=700, **mistral_kwargs) + article_polish_lm = VLLMClient(max_tokens=4000, **mistral_kwargs) lm_configs.set_conv_simulator_lm(conv_simulator_lm) lm_configs.set_question_asker_lm(question_asker_lm) @@ -68,12 +59,10 @@ def main(args): # STORM is a knowledge curation system which consumes information from the retrieval module. # Currently, the information source is the Internet and we use search engine API as the retrieval module. 
- if args.retriever == "bing": - rm = BingSearch( - bing_search_api=os.getenv("BING_SEARCH_API_KEY"), k=engine_args.search_top_k - ) - elif args.retriever == "you": - rm = YouRM(ydc_api_key=os.getenv("YDC_API_KEY"), k=engine_args.search_top_k) + if args.retriever == 'bing': + rm = BingSearch(bing_search_api=os.getenv('BING_SEARCH_API_KEY'), k=engine_args.search_top_k) + elif args.retriever == 'you': + rm = YouRM(ydc_api_key=os.getenv('YDC_API_KEY'), k=engine_args.search_top_k) runner = STORMWikiRunner(engine_args, lm_configs, rm) @@ -85,28 +74,26 @@ def main(args): find_related_topic_example = Example( topic="Knowledge Curation", related_topics="https://en.wikipedia.org/wiki/Knowledge_management\n" - "https://en.wikipedia.org/wiki/Information_science\n" - "https://en.wikipedia.org/wiki/Library_science\n", + "https://en.wikipedia.org/wiki/Information_science\n" + "https://en.wikipedia.org/wiki/Library_science\n" ) gen_persona_example = Example( topic="Knowledge Curation", examples="Title: Knowledge management\n" - "Table of Contents: History\nResearch\n Dimensions\n Strategies\n Motivations\nKM technologies" - "\nKnowledge barriers\nKnowledge retention\nKnowledge audit\nKnowledge protection\n" - " Knowledge protection methods\n Formal methods\n Informal methods\n" - " Balancing knowledge protection and knowledge sharing\n Knowledge protection risks", + "Table of Contents: History\nResearch\n Dimensions\n Strategies\n Motivations\nKM technologies" + "\nKnowledge barriers\nKnowledge retention\nKnowledge audit\nKnowledge protection\n" + " Knowledge protection methods\n Formal methods\n Informal methods\n" + " Balancing knowledge protection and knowledge sharing\n Knowledge protection risks", personas="1. Historian of Knowledge Systems: This editor will focus on the history and evolution of knowledge curation. They will provide context on how knowledge curation has changed over time and its impact on modern practices.\n" - "2. Information Science Professional: With insights from 'Information science', this editor will explore the foundational theories, definitions, and philosophy that underpin knowledge curation\n" - "3. Digital Librarian: This editor will delve into the specifics of how digital libraries operate, including software, metadata, digital preservation.\n" - "4. Technical expert: This editor will focus on the technical aspects of knowledge curation, such as common features of content management systems.\n" - "5. Museum Curator: The museum curator will contribute expertise on the curation of physical items and the transition of these practices into the digital realm.", + "2. Information Science Professional: With insights from 'Information science', this editor will explore the foundational theories, definitions, and philosophy that underpin knowledge curation\n" + "3. Digital Librarian: This editor will delve into the specifics of how digital libraries operate, including software, metadata, digital preservation.\n" + "4. Technical expert: This editor will focus on the technical aspects of knowledge curation, such as common features of content management systems.\n" + "5. Museum Curator: The museum curator will contribute expertise on the curation of physical items and the transition of these practices into the digital realm." 
) runner.storm_knowledge_curation_module.persona_generator.create_writer_with_persona.find_related_topic.demos = [ - find_related_topic_example - ] + find_related_topic_example] runner.storm_knowledge_curation_module.persona_generator.create_writer_with_persona.gen_persona.demos = [ - gen_persona_example - ] + gen_persona_example] # A trade-off of adding one-shot example is that it will increase the input length of the prompt. Also, some # examples may be very long (e.g., an example for writing a section based on the given information), which may @@ -118,28 +105,24 @@ def main(args): topic="Example Topic", conv="Wikipedia Writer: ...\nExpert: ...\nWikipedia Writer: ...\nExpert: ...", old_outline="# Section 1\n## Subsection 1\n## Subsection 2\n" - "# Section 2\n## Subsection 1\n## Subsection 2\n" - "# Section 3", + "# Section 2\n## Subsection 1\n## Subsection 2\n" + "# Section 3", outline="# New Section 1\n## New Subsection 1\n## New Subsection 2\n" - "# New Section 2\n" - "# New Section 3\n## New Subsection 1\n## New Subsection 2\n## New Subsection 3", + "# New Section 2\n" + "# New Section 3\n## New Subsection 1\n## New Subsection 2\n## New Subsection 3" ) - runner.storm_outline_generation_module.write_outline.write_page_outline.demos = [ - write_page_outline_example - ] + runner.storm_outline_generation_module.write_outline.write_page_outline.demos = [write_page_outline_example] write_section_example = Example( info="[1]\nInformation in document 1\n[2]\nInformation in document 2\n[3]\nInformation in document 3", topic="Example Topic", section="Example Section", output="# Example Topic\n## Subsection 1\n" - "This is an example sentence [1]. This is another example sentence [2][3].\n" - "## Subsection 2\nThis is one more example sentence [1].", + "This is an example sentence [1]. This is another example sentence [2][3].\n" + "## Subsection 2\nThis is one more example sentence [1]." ) - runner.storm_article_generation.section_gen.write_section.demos = [ - write_section_example - ] + runner.storm_article_generation.section_gen.write_section.demos = [write_section_example] - topic = input("Topic: ") + topic = input('Topic: ') runner.run( topic=topic, do_research=args.do_research, @@ -151,90 +134,42 @@ def main(args): runner.summary() -if __name__ == "__main__": +if __name__ == '__main__': parser = ArgumentParser() # global arguments - parser.add_argument( - "--url", type=str, default="http://localhost", help="URL of the Ollama server." - ) - parser.add_argument( - "--port", type=int, default=11434, help="Port of the Ollama server." - ) - parser.add_argument( - "--model", type=str, default="llama3:latest", help="Model of the Ollama server." - ) - parser.add_argument( - "--output-dir", - type=str, - default="./results/ollama", - help="Directory to store the outputs.", - ) - parser.add_argument( - "--max-thread-num", - type=int, - default=3, - help="Maximum number of threads to use. The information seeking part and the article generation" - "part can speed up by using multiple threads. 
Consider reducing it if keep getting " - '"Exceed rate limit" error when calling LM API.', - ) - parser.add_argument( - "--retriever", - type=str, - choices=["bing", "you"], - help="The search engine API to use for retrieving information.", - ) + parser.add_argument('--url', type=str, default='http://localhost', + help='URL of the VLLM server.') + parser.add_argument('--port', type=int, default=8000, + help='Port of the VLLM server.') + parser.add_argument('--output-dir', type=str, default='./results/mistral_7b', + help='Directory to store the outputs.') + parser.add_argument('--max-thread-num', type=int, default=3, + help='Maximum number of threads to use. The information seeking part and the article generation' + 'part can speed up by using multiple threads. Consider reducing it if keep getting ' + '"Exceed rate limit" error when calling LM API.') + parser.add_argument('--retriever', type=str, choices=['bing', 'you'], + help='The search engine API to use for retrieving information.') # stage of the pipeline - parser.add_argument( - "--do-research", - action="store_true", - help="If True, simulate conversation to research the topic; otherwise, load the results.", - ) - parser.add_argument( - "--do-generate-outline", - action="store_true", - help="If True, generate an outline for the topic; otherwise, load the results.", - ) - parser.add_argument( - "--do-generate-article", - action="store_true", - help="If True, generate an article for the topic; otherwise, load the results.", - ) - parser.add_argument( - "--do-polish-article", - action="store_true", - help="If True, polish the article by adding a summarization section and (optionally) removing " - "duplicate content.", - ) + parser.add_argument('--do-research', action='store_true', + help='If True, simulate conversation to research the topic; otherwise, load the results.') + parser.add_argument('--do-generate-outline', action='store_true', + help='If True, generate an outline for the topic; otherwise, load the results.') + parser.add_argument('--do-generate-article', action='store_true', + help='If True, generate an article for the topic; otherwise, load the results.') + parser.add_argument('--do-polish-article', action='store_true', + help='If True, polish the article by adding a summarization section and (optionally) removing ' + 'duplicate content.') # hyperparameters for the pre-writing stage - parser.add_argument( - "--max-conv-turn", - type=int, - default=3, - help="Maximum number of questions in conversational question asking.", - ) - parser.add_argument( - "--max-perspective", - type=int, - default=3, - help="Maximum number of perspectives to consider in perspective-guided question asking.", - ) - parser.add_argument( - "--search-top-k", - type=int, - default=3, - help="Top k search results to consider for each search query.", - ) + parser.add_argument('--max-conv-turn', type=int, default=3, + help='Maximum number of questions in conversational question asking.') + parser.add_argument('--max-perspective', type=int, default=3, + help='Maximum number of perspectives to consider in perspective-guided question asking.') + parser.add_argument('--search-top-k', type=int, default=3, + help='Top k search results to consider for each search query.') # hyperparameters for the writing stage - parser.add_argument( - "--retrieve-top-k", - type=int, - default=3, - help="Top k collected references for each section title.", - ) - parser.add_argument( - "--remove-duplicate", - action="store_true", - help="If True, remove duplicate content from 
the article.", - ) + parser.add_argument('--retrieve-top-k', type=int, default=3, + help='Top k collected references for each section title.') + parser.add_argument('--remove-duplicate', action='store_true', + help='If True, remove duplicate content from the article.') - main(parser.parse_args()) + main(parser.parse_args()) \ No newline at end of file diff --git a/examples/run_storm_wiki_serper.py b/examples/run_storm_wiki_serper.py index 70c2ac1d..df27b98b 100644 --- a/examples/run_storm_wiki_serper.py +++ b/examples/run_storm_wiki_serper.py @@ -101,7 +101,7 @@ def main(args): parser.add_argument( "--output-dir", type=str, - default="./results/claude", + default="./results/serper", help="Directory to store the outputs.", ) parser.add_argument( From 84de6324504ff8a5f3fa3ea39dc7385c7827c3bd Mon Sep 17 00:00:00 2001 From: zenith110 Date: Sun, 28 Jul 2024 11:34:51 -0400 Subject: [PATCH 5/7] Updated ollama and mistral --- examples/run_storm_wiki_mistral.py | 182 ++++++++++------------------- examples/run_storm_wiki_ollama.py | 38 +++--- 2 files changed, 82 insertions(+), 138 deletions(-) diff --git a/examples/run_storm_wiki_mistral.py b/examples/run_storm_wiki_mistral.py index 291d2879..76f85dc3 100644 --- a/examples/run_storm_wiki_mistral.py +++ b/examples/run_storm_wiki_mistral.py @@ -15,33 +15,26 @@ storm_gen_article.txt # Final article generated storm_gen_article_polished.txt # Polished final article (if args.do_polish_article is True) """ - import os from argparse import ArgumentParser from dspy import Example -from knowledge_storm import ( - STORMWikiRunnerArguments, - STORMWikiRunner, - STORMWikiLMConfigs, -) +from knowledge_storm import STORMWikiRunnerArguments, STORMWikiRunner, STORMWikiLMConfigs from knowledge_storm.lm import VLLMClient from knowledge_storm.rm import YouRM, BingSearch from knowledge_storm.utils import load_api_key def main(args): - load_api_key(toml_file_path="secrets.toml") + load_api_key(toml_file_path='secrets.toml') lm_configs = STORMWikiLMConfigs() mistral_kwargs = { "model": "mistralai/Mistral-7B-Instruct-v0.2", "port": args.port, "url": args.url, - "stop": ( - "\n\n---", - ), # dspy uses "\n\n---" to separate examples. Open models sometimes generate this. + "stop": ('\n\n---',) # dspy uses "\n\n---" to separate examples. Open models sometimes generate this. } conv_simulator_lm = VLLMClient(max_tokens=500, **mistral_kwargs) @@ -66,12 +59,10 @@ def main(args): # STORM is a knowledge curation system which consumes information from the retrieval module. # Currently, the information source is the Internet and we use search engine API as the retrieval module. 
- if args.retriever == "bing": - rm = BingSearch( - bing_search_api=os.getenv("BING_SEARCH_API_KEY"), k=engine_args.search_top_k - ) - elif args.retriever == "you": - rm = YouRM(ydc_api_key=os.getenv("YDC_API_KEY"), k=engine_args.search_top_k) + if args.retriever == 'bing': + rm = BingSearch(bing_search_api=os.getenv('BING_SEARCH_API_KEY'), k=engine_args.search_top_k) + elif args.retriever == 'you': + rm = YouRM(ydc_api_key=os.getenv('YDC_API_KEY'), k=engine_args.search_top_k) runner = STORMWikiRunner(engine_args, lm_configs, rm) @@ -83,28 +74,26 @@ def main(args): find_related_topic_example = Example( topic="Knowledge Curation", related_topics="https://en.wikipedia.org/wiki/Knowledge_management\n" - "https://en.wikipedia.org/wiki/Information_science\n" - "https://en.wikipedia.org/wiki/Library_science\n", + "https://en.wikipedia.org/wiki/Information_science\n" + "https://en.wikipedia.org/wiki/Library_science\n" ) gen_persona_example = Example( topic="Knowledge Curation", examples="Title: Knowledge management\n" - "Table of Contents: History\nResearch\n Dimensions\n Strategies\n Motivations\nKM technologies" - "\nKnowledge barriers\nKnowledge retention\nKnowledge audit\nKnowledge protection\n" - " Knowledge protection methods\n Formal methods\n Informal methods\n" - " Balancing knowledge protection and knowledge sharing\n Knowledge protection risks", + "Table of Contents: History\nResearch\n Dimensions\n Strategies\n Motivations\nKM technologies" + "\nKnowledge barriers\nKnowledge retention\nKnowledge audit\nKnowledge protection\n" + " Knowledge protection methods\n Formal methods\n Informal methods\n" + " Balancing knowledge protection and knowledge sharing\n Knowledge protection risks", personas="1. Historian of Knowledge Systems: This editor will focus on the history and evolution of knowledge curation. They will provide context on how knowledge curation has changed over time and its impact on modern practices.\n" - "2. Information Science Professional: With insights from 'Information science', this editor will explore the foundational theories, definitions, and philosophy that underpin knowledge curation\n" - "3. Digital Librarian: This editor will delve into the specifics of how digital libraries operate, including software, metadata, digital preservation.\n" - "4. Technical expert: This editor will focus on the technical aspects of knowledge curation, such as common features of content management systems.\n" - "5. Museum Curator: The museum curator will contribute expertise on the curation of physical items and the transition of these practices into the digital realm.", + "2. Information Science Professional: With insights from 'Information science', this editor will explore the foundational theories, definitions, and philosophy that underpin knowledge curation\n" + "3. Digital Librarian: This editor will delve into the specifics of how digital libraries operate, including software, metadata, digital preservation.\n" + "4. Technical expert: This editor will focus on the technical aspects of knowledge curation, such as common features of content management systems.\n" + "5. Museum Curator: The museum curator will contribute expertise on the curation of physical items and the transition of these practices into the digital realm." 
) runner.storm_knowledge_curation_module.persona_generator.create_writer_with_persona.find_related_topic.demos = [ - find_related_topic_example - ] + find_related_topic_example] runner.storm_knowledge_curation_module.persona_generator.create_writer_with_persona.gen_persona.demos = [ - gen_persona_example - ] + gen_persona_example] # A trade-off of adding one-shot example is that it will increase the input length of the prompt. Also, some # examples may be very long (e.g., an example for writing a section based on the given information), which may @@ -116,28 +105,24 @@ def main(args): topic="Example Topic", conv="Wikipedia Writer: ...\nExpert: ...\nWikipedia Writer: ...\nExpert: ...", old_outline="# Section 1\n## Subsection 1\n## Subsection 2\n" - "# Section 2\n## Subsection 1\n## Subsection 2\n" - "# Section 3", + "# Section 2\n## Subsection 1\n## Subsection 2\n" + "# Section 3", outline="# New Section 1\n## New Subsection 1\n## New Subsection 2\n" - "# New Section 2\n" - "# New Section 3\n## New Subsection 1\n## New Subsection 2\n## New Subsection 3", + "# New Section 2\n" + "# New Section 3\n## New Subsection 1\n## New Subsection 2\n## New Subsection 3" ) - runner.storm_outline_generation_module.write_outline.write_page_outline.demos = [ - write_page_outline_example - ] + runner.storm_outline_generation_module.write_outline.write_page_outline.demos = [write_page_outline_example] write_section_example = Example( info="[1]\nInformation in document 1\n[2]\nInformation in document 2\n[3]\nInformation in document 3", topic="Example Topic", section="Example Section", output="# Example Topic\n## Subsection 1\n" - "This is an example sentence [1]. This is another example sentence [2][3].\n" - "## Subsection 2\nThis is one more example sentence [1].", + "This is an example sentence [1]. This is another example sentence [2][3].\n" + "## Subsection 2\nThis is one more example sentence [1]." ) - runner.storm_article_generation.section_gen.write_section.demos = [ - write_section_example - ] + runner.storm_article_generation.section_gen.write_section.demos = [write_section_example] - topic = input("Topic: ") + topic = input('Topic: ') runner.run( topic=topic, do_research=args.do_research, @@ -149,87 +134,42 @@ def main(args): runner.summary() -if __name__ == "__main__": +if __name__ == '__main__': parser = ArgumentParser() # global arguments - parser.add_argument( - "--url", type=str, default="http://localhost", help="URL of the VLLM server." - ) - parser.add_argument( - "--port", type=int, default=8000, help="Port of the VLLM server." - ) - parser.add_argument( - "--output-dir", - type=str, - default="./results/mistral_7b", - help="Directory to store the outputs.", - ) - parser.add_argument( - "--max-thread-num", - type=int, - default=3, - help="Maximum number of threads to use. The information seeking part and the article generation" - "part can speed up by using multiple threads. 
Consider reducing it if keep getting " - '"Exceed rate limit" error when calling LM API.', - ) - parser.add_argument( - "--retriever", - type=str, - choices=["bing", "you"], - help="The search engine API to use for retrieving information.", - ) + parser.add_argument('--url', type=str, default='http://localhost', + help='URL of the VLLM server.') + parser.add_argument('--port', type=int, default=8000, + help='Port of the VLLM server.') + parser.add_argument('--output-dir', type=str, default='./results/mistral_7b', + help='Directory to store the outputs.') + parser.add_argument('--max-thread-num', type=int, default=3, + help='Maximum number of threads to use. The information seeking part and the article generation' + 'part can speed up by using multiple threads. Consider reducing it if keep getting ' + '"Exceed rate limit" error when calling LM API.') + parser.add_argument('--retriever', type=str, choices=['bing', 'you'], + help='The search engine API to use for retrieving information.') # stage of the pipeline - parser.add_argument( - "--do-research", - action="store_true", - help="If True, simulate conversation to research the topic; otherwise, load the results.", - ) - parser.add_argument( - "--do-generate-outline", - action="store_true", - help="If True, generate an outline for the topic; otherwise, load the results.", - ) - parser.add_argument( - "--do-generate-article", - action="store_true", - help="If True, generate an article for the topic; otherwise, load the results.", - ) - parser.add_argument( - "--do-polish-article", - action="store_true", - help="If True, polish the article by adding a summarization section and (optionally) removing " - "duplicate content.", - ) + parser.add_argument('--do-research', action='store_true', + help='If True, simulate conversation to research the topic; otherwise, load the results.') + parser.add_argument('--do-generate-outline', action='store_true', + help='If True, generate an outline for the topic; otherwise, load the results.') + parser.add_argument('--do-generate-article', action='store_true', + help='If True, generate an article for the topic; otherwise, load the results.') + parser.add_argument('--do-polish-article', action='store_true', + help='If True, polish the article by adding a summarization section and (optionally) removing ' + 'duplicate content.') # hyperparameters for the pre-writing stage - parser.add_argument( - "--max-conv-turn", - type=int, - default=3, - help="Maximum number of questions in conversational question asking.", - ) - parser.add_argument( - "--max-perspective", - type=int, - default=3, - help="Maximum number of perspectives to consider in perspective-guided question asking.", - ) - parser.add_argument( - "--search-top-k", - type=int, - default=3, - help="Top k search results to consider for each search query.", - ) + parser.add_argument('--max-conv-turn', type=int, default=3, + help='Maximum number of questions in conversational question asking.') + parser.add_argument('--max-perspective', type=int, default=3, + help='Maximum number of perspectives to consider in perspective-guided question asking.') + parser.add_argument('--search-top-k', type=int, default=3, + help='Top k search results to consider for each search query.') # hyperparameters for the writing stage - parser.add_argument( - "--retrieve-top-k", - type=int, - default=3, - help="Top k collected references for each section title.", - ) - parser.add_argument( - "--remove-duplicate", - action="store_true", - help="If True, remove duplicate content from 
the article.", - ) + parser.add_argument('--retrieve-top-k', type=int, default=3, + help='Top k collected references for each section title.') + parser.add_argument('--remove-duplicate', action='store_true', + help='If True, remove duplicate content from the article.') - main(parser.parse_args()) + main(parser.parse_args()) \ No newline at end of file diff --git a/examples/run_storm_wiki_ollama.py b/examples/run_storm_wiki_ollama.py index 76f85dc3..35ba99e1 100644 --- a/examples/run_storm_wiki_ollama.py +++ b/examples/run_storm_wiki_ollama.py @@ -1,8 +1,8 @@ """ -STORM Wiki pipeline powered by Mistral-7B-Instruct-v0.2 hosted by VLLM server and You.com search engine. +STORM Wiki pipeline powered by local model hosted by Ollama server and You.com or Bing search engine. You need to set up the following environment variables to run this script: - YDC_API_KEY: You.com API key; or, BING_SEARCH_API_KEY: Bing Search API key -You also need to have a VLLM server running with the Mistral-7B-Instruct-v0.2 model. Specify `--url` and `--port` accordingly. +You also need to have a Ollama server running with the llama3 model or other. Specify `--url`, `--port` and `--model` accordingly. Output will be structured as below args.output_dir/ @@ -16,32 +16,34 @@ storm_gen_article_polished.txt # Polished final article (if args.do_polish_article is True) """ import os +import sys from argparse import ArgumentParser from dspy import Example -from knowledge_storm import STORMWikiRunnerArguments, STORMWikiRunner, STORMWikiLMConfigs -from knowledge_storm.lm import VLLMClient -from knowledge_storm.rm import YouRM, BingSearch -from knowledge_storm.utils import load_api_key +sys.path.append('./src') +from lm import OllamaClient +from rm import YouRM, BingSearch +from storm_wiki.engine import STORMWikiRunnerArguments, STORMWikiRunner, STORMWikiLMConfigs +from utils import load_api_key def main(args): load_api_key(toml_file_path='secrets.toml') lm_configs = STORMWikiLMConfigs() - mistral_kwargs = { - "model": "mistralai/Mistral-7B-Instruct-v0.2", + ollama_kwargs = { + "model": args.model, "port": args.port, "url": args.url, "stop": ('\n\n---',) # dspy uses "\n\n---" to separate examples. Open models sometimes generate this. 
} - conv_simulator_lm = VLLMClient(max_tokens=500, **mistral_kwargs) - question_asker_lm = VLLMClient(max_tokens=500, **mistral_kwargs) - outline_gen_lm = VLLMClient(max_tokens=400, **mistral_kwargs) - article_gen_lm = VLLMClient(max_tokens=700, **mistral_kwargs) - article_polish_lm = VLLMClient(max_tokens=4000, **mistral_kwargs) + conv_simulator_lm = OllamaClient(max_tokens=500, **ollama_kwargs) + question_asker_lm = OllamaClient(max_tokens=500, **ollama_kwargs) + outline_gen_lm = OllamaClient(max_tokens=400, **ollama_kwargs) + article_gen_lm = OllamaClient(max_tokens=700, **ollama_kwargs) + article_polish_lm = OllamaClient(max_tokens=4000, **ollama_kwargs) lm_configs.set_conv_simulator_lm(conv_simulator_lm) lm_configs.set_question_asker_lm(question_asker_lm) @@ -138,10 +140,12 @@ def main(args): parser = ArgumentParser() # global arguments parser.add_argument('--url', type=str, default='http://localhost', - help='URL of the VLLM server.') - parser.add_argument('--port', type=int, default=8000, - help='Port of the VLLM server.') - parser.add_argument('--output-dir', type=str, default='./results/mistral_7b', + help='URL of the Ollama server.') + parser.add_argument('--port', type=int, default=11434, + help='Port of the Ollama server.') + parser.add_argument('--model', type=str, default='llama3:latest', + help='Model of the Ollama server.') + parser.add_argument('--output-dir', type=str, default='./results/ollama', help='Directory to store the outputs.') parser.add_argument('--max-thread-num', type=int, default=3, help='Maximum number of threads to use. The information seeking part and the article generation' From a1844fa4aa6fb2ab48060c7ceb25ca53c4bd5950 Mon Sep 17 00:00:00 2001 From: zenith110 Date: Mon, 29 Jul 2024 12:37:27 -0400 Subject: [PATCH 6/7] Made changes based off pr feedback --- knowledge_storm/rm.py | 255 ++++++++++++++++++++---------------------- 1 file changed, 122 insertions(+), 133 deletions(-) diff --git a/knowledge_storm/rm.py b/knowledge_storm/rm.py index 441ea570..3d6092dd 100644 --- a/knowledge_storm/rm.py +++ b/knowledge_storm/rm.py @@ -1,7 +1,6 @@ import logging import os -from typing import Callable, Union, List, Dict -from typing_extensions import Dict +from typing import Callable, Union, List import dspy import pandas as pd @@ -11,18 +10,15 @@ from langchain_qdrant import Qdrant from qdrant_client import QdrantClient, models from tqdm import tqdm -import requests -import json -from utils import WebPageHelper + +from .utils import WebPageHelper class YouRM(dspy.Retrieve): def __init__(self, ydc_api_key=None, k=3, is_valid_source: Callable = None): super().__init__(k=k) if not ydc_api_key and not os.environ.get("YDC_API_KEY"): - raise RuntimeError( - "You must supply ydc_api_key or set environment variable YDC_API_KEY" - ) + raise RuntimeError("You must supply ydc_api_key or set environment variable YDC_API_KEY") elif ydc_api_key: self.ydc_api_key = ydc_api_key else: @@ -39,11 +35,9 @@ def get_usage_and_reset(self): usage = self.usage self.usage = 0 - return {"YouRM": usage} + return {'YouRM': usage} - def forward( - self, query_or_queries: Union[str, List[str]], exclude_urls: List[str] = [] - ): + def forward(self, query_or_queries: Union[str, List[str]], exclude_urls: List[str] = []): """Search with You.com for self.k top passages for query or queries Args: @@ -69,30 +63,21 @@ def forward( ).json() authoritative_results = [] - for r in results["hits"]: - if self.is_valid_source(r["url"]) and r["url"] not in exclude_urls: + for r in results['hits']: + if 
self.is_valid_source(r['url']) and r['url'] not in exclude_urls: authoritative_results.append(r) - if "hits" in results: - collected_results.extend(authoritative_results[: self.k]) + if 'hits' in results: + collected_results.extend(authoritative_results[:self.k]) except Exception as e: - logging.error(f"Error occurs when searching query {query}: {e}") + logging.error(f'Error occurs when searching query {query}: {e}') return collected_results class BingSearch(dspy.Retrieve): - def __init__( - self, - bing_search_api_key=None, - k=3, - is_valid_source: Callable = None, - min_char_count: int = 150, - snippet_chunk_size: int = 1000, - webpage_helper_max_threads=10, - mkt="en-US", - language="en", - **kwargs, - ): + def __init__(self, bing_search_api_key=None, k=3, is_valid_source: Callable = None, + min_char_count: int = 150, snippet_chunk_size: int = 1000, webpage_helper_max_threads=10, + mkt='en-US', language='en', **kwargs): """ Params: min_char_count: Minimum character count for the article to be considered valid. @@ -104,18 +89,22 @@ def __init__( super().__init__(k=k) if not bing_search_api_key and not os.environ.get("BING_SEARCH_API_KEY"): raise RuntimeError( - "You must supply bing_search_subscription_key or set environment variable BING_SEARCH_API_KEY" - ) + "You must supply bing_search_subscription_key or set environment variable BING_SEARCH_API_KEY") elif bing_search_api_key: self.bing_api_key = bing_search_api_key else: self.bing_api_key = os.environ["BING_SEARCH_API_KEY"] self.endpoint = "https://api.bing.microsoft.com/v7.0/search" - self.params = {"mkt": mkt, "setLang": language, "count": k, **kwargs} + self.params = { + 'mkt': mkt, + "setLang": language, + "count": k, + **kwargs + } self.webpage_helper = WebPageHelper( min_char_count=min_char_count, snippet_chunk_size=snippet_chunk_size, - max_thread_num=webpage_helper_max_threads, + max_thread_num=webpage_helper_max_threads ) self.usage = 0 @@ -129,11 +118,9 @@ def get_usage_and_reset(self): usage = self.usage self.usage = 0 - return {"BingSearch": usage} + return {'BingSearch': usage} - def forward( - self, query_or_queries: Union[str, List[str]], exclude_urls: List[str] = [] - ): + def forward(self, query_or_queries: Union[str, List[str]], exclude_urls: List[str] = []): """Search with Bing for self.k top passages for query or queries Args: @@ -157,26 +144,22 @@ def forward( for query in queries: try: results = requests.get( - self.endpoint, headers=headers, params={**self.params, "q": query} + self.endpoint, + headers=headers, + params={**self.params, 'q': query} ).json() - for d in results["webPages"]["value"]: - if self.is_valid_source(d["url"]) and d["url"] not in exclude_urls: - url_to_results[d["url"]] = { - "url": d["url"], - "title": d["name"], - "description": d["snippet"], - } + for d in results['webPages']['value']: + if self.is_valid_source(d['url']) and d['url'] not in exclude_urls: + url_to_results[d['url']] = {'url': d['url'], 'title': d['name'], 'description': d['snippet']} except Exception as e: - logging.error(f"Error occurs when searching query {query}: {e}") + logging.error(f'Error occurs when searching query {query}: {e}') - valid_url_to_snippets = self.webpage_helper.urls_to_snippets( - list(url_to_results.keys()) - ) + valid_url_to_snippets = self.webpage_helper.urls_to_snippets(list(url_to_results.keys())) collected_results = [] for url in valid_url_to_snippets: r = url_to_results[url] - r["snippets"] = valid_url_to_snippets[url]["snippets"] + r['snippets'] = valid_url_to_snippets[url]['snippets'] 
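+            # Merge the scraped snippets into the Bing result so that each entry
+            # carries url, title, description, and snippets, as STORM expects.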
collected_results.append(r) return collected_results @@ -194,15 +177,13 @@ class VectorRM(dspy.Retrieve): The documents should be stored in a CSV file. """ - def __init__( - self, - collection_name: str = "my_documents", - embedding_model: str = "BAAI/bge-m3", - device: str = "mps", - k: int = 3, - chunk_size: int = 500, - chunk_overlap: int = 100, - ): + def __init__(self, + collection_name: str = "my_documents", + embedding_model: str = 'BAAI/bge-m3', + device: str = "mps", + k: int = 3, + chunk_size: int = 500, + chunk_overlap: int = 100): """ Params: collection_name: Name of the Qdrant collection. @@ -218,9 +199,7 @@ def __init__( model_kwargs = {"device": device} encode_kwargs = {"normalize_embeddings": True} self.model = HuggingFaceEmbeddings( - model_name=embedding_model, - model_kwargs=model_kwargs, - encode_kwargs=encode_kwargs, + model_name=embedding_model, model_kwargs=model_kwargs, encode_kwargs=encode_kwargs ) self.chunk_size = chunk_size @@ -237,24 +216,18 @@ def _check_create_collection(self): if self.client is None: raise ValueError("Qdrant client is not initialized.") if self.client.collection_exists(collection_name=f"{self.collection_name}"): - print( - f"Collection {self.collection_name} exists. Loading the collection..." - ) + print(f"Collection {self.collection_name} exists. Loading the collection...") self.qdrant = Qdrant( client=self.client, collection_name=self.collection_name, embeddings=self.model, ) else: - print( - f"Collection {self.collection_name} does not exist. Creating the collection..." - ) + print(f"Collection {self.collection_name} does not exist. Creating the collection...") # create the collection self.client.create_collection( collection_name=f"{self.collection_name}", - vectors_config=models.VectorParams( - size=1024, distance=models.Distance.COSINE - ), + vectors_config=models.VectorParams(size=1024, distance=models.Distance.COSINE), ) self.qdrant = Qdrant( client=self.client, @@ -300,13 +273,13 @@ def init_offline_vector_db(self, vector_store_path: str): raise ValueError(f"Error occurs when loading the vector store: {e}") def update_vector_store( - self, - file_path: str, - content_column: str, - title_column: str = "title", - url_column: str = "url", - desc_column: str = "description", - batch_size: int = 64, + self, + file_path: str, + content_column: str, + title_column: str = "title", + url_column: str = "url", + desc_column: str = "description", + batch_size: int = 64 ): """ Takes a CSV file where each row is a document and has columns for content, title, url, and description. @@ -323,7 +296,7 @@ def update_vector_store( if file_path is None: raise ValueError("Please provide a file path.") # check if the file is a csv file - if not file_path.endswith(".csv"): + if not file_path.endswith('.csv'): raise ValueError(f"Not valid file format. Please provide a csv file.") if content_column is None: raise ValueError("Please provide the name of the content column.") @@ -337,9 +310,7 @@ def update_vector_store( df = pd.read_csv(file_path) # check that content column exists and url column exists if content_column not in df.columns: - raise ValueError( - f"Content column {content_column} not found in the csv file." 
- ) + raise ValueError(f"Content column {content_column} not found in the csv file.") if url_column not in df.columns: raise ValueError(f"URL column {url_column} not found in the csv file.") @@ -347,17 +318,16 @@ def update_vector_store( Document( page_content=row[content_column], metadata={ - "title": row.get(title_column, ""), + "title": row.get(title_column, ''), "url": row[url_column], - "description": row.get(desc_column, ""), - }, + "description": row.get(desc_column, ''), + } ) - for row in df.to_dict(orient="records") + for row in df.to_dict(orient='records') ] # split the documents from langchain_text_splitters import RecursiveCharacterTextSplitter - text_splitter = RecursiveCharacterTextSplitter( chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap, @@ -375,7 +345,7 @@ def update_vector_store( " ", "\u200B", # Zero-width space "", - ], + ] ) split_documents = text_splitter.split_documents(documents) @@ -393,7 +363,7 @@ def get_usage_and_reset(self): usage = self.usage self.usage = 0 - return {"VectorRM": usage} + return {'VectorRM': usage} def get_vector_count(self): """ @@ -426,63 +396,71 @@ def forward(self, query_or_queries: Union[str, List[str]], exclude_urls: List[st related_docs = self.qdrant.similarity_search_with_score(query, k=self.k) for i in range(len(related_docs)): doc = related_docs[i][0] - collected_results.append( - { - "description": doc.metadata["description"], - "snippets": [doc.page_content], - "title": doc.metadata["title"], - "url": doc.metadata["url"], - } - ) + collected_results.append({ + 'description': doc.metadata['description'], + 'snippets': [doc.page_content], + 'title': doc.metadata['title'], + 'url': doc.metadata['url'], + }) return collected_results - class SerperRM(dspy.Retrieve): """Retrieve information from custom queries using Serper.dev. - - To be compatible with STORM, the results should have the following fields: - - snippet: Snippets that will be used for the document - - title: The title of the document. - - url: The URL of the document. STORM use url as the unique identifier of the document, so ensure different - documents have different urls. - - description (optional): The description of the document. - """ + """ + Args: + serper_search_api_key str: Api key to run serper, can be found by creating an account on https://serper.dev/ + query_params (dict or list of dict): paramaters in dictionary or list of dictionaries that has a max size of 100 that will be used to query. + values: + q str: query that will be used with google search + type str: type that will be used for browsing google. Types are search, images, video, maps, places and many more. Refer to the playground for the types. + gl str: Country that will be focused on for the search + location str: Country where the search will originate from. All locates can be found here: https://api.serper.dev/locations. + autocorrect bool: Enable autocorrect on the queries while searching, if query is misspelled, will be updated. + results int: Max number of results per page. + page int: Max number of pages per call. + tbs str: date time range, automically set to any time by default. + qdr:h str: Date time range for past hour. + qdr:d str: Date time range for past 24 hours. + qdr:w str: Date time range for past week. + qdr:m str: Date time range for past month. + qdr:y str: Date time range for past year. 
+    """
     def __init__(self, serper_search_api_key=None, query_params=None):
         super().__init__()
         self.usage = 0
         self.query_params = query_params
         self.serper_search_api_key = serper_search_api_key
-        if not self.serper_search_api_key and not os.environ.get("SERPER_API_KEY"):
+        if not self.serper_search_api_key and not os.environ.get('SERPER_API_KEY'):
             raise RuntimeError(
-                "You must supply a serper_search_api_key param or set environment variable SERPER_API_KEY"
+                'You must supply a serper_search_api_key param or set environment variable SERPER_API_KEY'
             )
         elif self.serper_search_api_key:
             self.serper_search_api_key = serper_search_api_key
         else:
-            self.serper_search_api_key = os.environ["SERPER_API_KEY"]
+            self.serper_search_api_key = os.environ['SERPER_API_KEY']

-        self.base_url = "https://google.serper.dev"
+        self.base_url = 'https://google.serper.dev'

     def serper_runner(self, query_params):
-        self.search_url = f"{self.base_url}/search"
+        self.search_url = f'{self.base_url}/search'

         headers = {
-            "X-API-KEY": self.serper_search_api_key,
-            "Content-Type": "application/json",
+            'X-API-KEY': self.serper_search_api_key,
+            'Content-Type': 'application/json',
         }

         response = requests.request(
-            "POST", self.search_url, headers=headers, json=query_params
+            'POST', self.search_url, headers=headers, json=query_params
         )

         if response == None:
             raise RuntimeError(
-                f"Error had occured while running the process {process_name}.\n Error is {response.reason}, had failed with status code {response.status_code}"
+                f'Error occurred while running the search process.\n Error is {response.reason}, failed with status code {response.status_code}'
             )

         return response.json()

     def get_usage_and_reset(self):
         usage = self.usage
         self.usage = 0
-        return {"SerperRM": usage}
+        return {'SerperRM': usage}

     def forward(self, query_or_queries: Union[str, List[str]], exclude_urls: List[str]):
         """
         Calls the API and searches for the query passed in.
-
+
+
         Args:
             query_or_queries (Union[str, List[str]]): The query or queries to search for.
             exclude_urls (List[str]): Dummy parameter to match the interface. Does not have any effect.

         Returns:
-            a list of Dicts, each dict has keys of 'description', 'snippets' (list of strings), 'title', 'url'
+            a list of dictionaries, each with the keys 'description', 'snippets' (list of strings), 'title', and 'url'
         """
         queries = (
             [query_or_queries]
             if isinstance(query_or_queries, str)
             else query_or_queries
         )

         self.results = []
         collected_results = []
         for query in queries:
-            if query == "Queries:":
+            if query == 'Queries:':
                 continue
             query_params = self.query_params
-            query_params["q"] = query
-            query_params["type"] = "search"
+            """
+            All paramaters can be found in the playground: https://serper.dev/playground
+            """
+            # Sets the json value for query to be the query that is being parsed.
+            query_params['q'] = query
+
+            # Sets the type to 'search'; Google also provides other types such as images, video, places, and maps.
+            query_params['type'] = 'search'
+
             self.result = self.serper_runner(query_params)
             self.results.append(self.result)

+        # Array of result dictionaries that STORM uses to build its reference JSON files.
         collected_results = []
         for result in self.results:
             try:
+                # An array of dictionaries containing the snippets, title, and url of each document.
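+                # For reference, each Serper response is assumed to look roughly like
+                # the following (illustrative, based on https://serper.dev/playground):
+                #   {'organic': [{'title': ..., 'link': ..., 'snippet': ...}, ...],
+                #    'knowledgeGraph': {'title': ..., 'description': ...}}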
+                organic_results = result.get('organic')

-                knowledge_graph = result.get("knowledgeGraph")
+                knowledge_graph = result.get('knowledgeGraph')
                 for organic in organic_results:
                     snippets = []
-                    snippets.append(organic.get("snippet"))
+                    snippets.append(organic.get('snippet'))
                     if knowledge_graph != None:
                         collected_results.append(
                             {
-                                "snippets": snippets,
-                                "title": organic.get("title"),
-                                "url": organic.get("link"),
-                                "description": knowledge_graph.get("description"),
+                                'snippets': snippets,
+                                'title': organic.get('title'),
+                                'url': organic.get('link'),
+                                'description': knowledge_graph.get('description'),
                             }
                         )
                     else:
+                        # It is common for the knowledge graph to be missing; fall back to an empty description.
                         collected_results.append(
                             {
-                                "snippets": snippets,
-                                "title": result.get("title"),
-                                "url": result.get("link"),
-                                "description": "",
+                                'snippets': snippets,
+                                'title': organic.get('title'),
+                                'url': organic.get('link'),
+                                'description': '',
                             }
                         )
             except:
                 continue

-        return collected_results
+        return collected_results
\ No newline at end of file
From 985b3faac3a5ec62185a2d0dd816351635ecded9 Mon Sep 17 00:00:00 2001
From: Yijia Shao <67158122+shaoyijia@users.noreply.github.com>
Date: Wed, 31 Jul 2024 11:22:31 +0800
Subject: [PATCH 7/7] Update comments for SerperRM.

---
 knowledge_storm/rm.py | 39 ++++++++++++++++++---------------------
 1 file changed, 18 insertions(+), 21 deletions(-)

diff --git a/knowledge_storm/rm.py b/knowledge_storm/rm.py
index 3d6092dd..0951ef33 100644
--- a/knowledge_storm/rm.py
+++ b/knowledge_storm/rm.py
@@ -406,29 +406,27 @@ def forward(self, query_or_queries: Union[str, List[str]], exclude_urls: List[st
         return collected_results

 class SerperRM(dspy.Retrieve):
-    """Retrieve information from custom queries using Serper.dev.
-    """
-
-    """
-    Args:
-        serper_search_api_key str: Api key to run serper, can be found by creating an account on https://serper.dev/
-        query_params (dict or list of dict): paramaters in dictionary or list of dictionaries that has a max size of 100 that will be used to query.
-            values:
+    """Retrieve information from custom queries using Serper.dev."""
+
+    def __init__(self, serper_search_api_key=None, query_params=None):
+        """Args:
+            serper_search_api_key str: API key to run serper, can be found by creating an account on https://serper.dev/
+            query_params (dict or list of dict): parameters in dictionary or list of dictionaries that has a max size of 100 that will be used to query.
+                Commonly used fields are as follows (see more information in https://serper.dev/playground):
             q str: query that will be used with google search
-            type str: type that will be used for browsing google. Types are search, images, video, maps, places and many more. Refer to the playground for the types.
+            type str: type that will be used for browsing google. Types are search, images, video, maps, places, etc.
             gl str: Country that will be focused on for the search
             location str: Country where the search will originate from. All locates can be found here: https://api.serper.dev/locations.
             autocorrect bool: Enable autocorrect on the queries while searching, if query is misspelled, will be updated.
             results int: Max number of results per page.
             page int: Max number of pages per call.
-            tbs str: date time range, automically set to any time by default.
-            qdr:h str: Date time range for past hour.
-            qdr:d str: Date time range for past 24 hours.
-            qdr:w str: Date time range for past week.
-            qdr:m str: Date time range for past month.
-            qdr:y str: Date time range for past year.
-        """
-    def __init__(self, serper_search_api_key=None, query_params=None):
+            tbs str: date time range, automatically set to any time by default.
+            qdr:h str: Date time range for the past hour.
+            qdr:d str: Date time range for the past 24 hours.
+            qdr:w str: Date time range for the past week.
+            qdr:m str: Date time range for the past month.
+            qdr:y str: Date time range for the past year.
+        """
         super().__init__()
         self.usage = 0
         self.query_params = query_params
@@ -495,9 +493,8 @@ def forward(self, query_or_queries: Union[str, List[str]], exclude_urls: List[st
             if query == 'Queries:':
                 continue
             query_params = self.query_params
-            """
-            All paramaters can be found in the playground: https://serper.dev/playground
-            """
+
+            # All available parameters can be found in the playground: https://serper.dev/playground
             # Sets the json value for query to be the query that is being parsed.
             query_params['q'] = query
@@ -541,4 +538,4 @@ def forward(self, query_or_queries: Union[str, List[str]], exclude_urls: List[st
         except:
             continue

-        return collected_results
\ No newline at end of file
+        return collected_results
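
For reference, a minimal usage sketch of the SerperRM retriever added in this series (illustrative only: the query_params values are assumptions rather than package defaults, and a valid SERPER_API_KEY is assumed to be available in the environment):

import os

from knowledge_storm.rm import SerperRM

# Illustrative query parameters; every supported field is documented at
# https://serper.dev/playground.
query_params = {'autocorrect': True, 'results': 10, 'page': 1}

rm = SerperRM(
    serper_search_api_key=os.getenv('SERPER_API_KEY'),
    query_params=query_params,
)

# forward() accepts a single query or a list of queries; each returned item is a
# dictionary with 'description', 'snippets' (list of strings), 'title', and 'url'.
for result in rm.forward('knowledge curation', exclude_urls=[]):
    print(result['title'], result['url'])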