
Commit 4d5d680
Merge pull request #1010 from v1xingyue/patch-custom-fetch
feat: Add custom fetch logic for agent
odilitime authored Dec 12, 2024
2 parents 5f266f1 + 0128a1e commit 4d5d680
Showing 3 changed files with 104 additions and 43 deletions.
7 changes: 7 additions & 0 deletions agent/src/index.ts
@@ -60,6 +60,12 @@ export const wait = (minTime: number = 1000, maxTime: number = 3000) => {
     return new Promise((resolve) => setTimeout(resolve, waitTime));
 };
 
+const logFetch = async (url: string, options: any) => {
+    elizaLogger.info(`Fetching ${url}`);
+    elizaLogger.info(options);
+    return fetch(url, options);
+};
+
 export function parseArguments(): {
     character?: string;
     characters?: string;
@@ -473,6 +479,7 @@ export async function createAgent(
         services: [],
         managers: [],
         cacheManager: cache,
+        fetch: logFetch,
     });
 }

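The logFetch wrapper above only logs before delegating to the global fetch. Since the new hook accepts any fetch-compatible function, the same slot can carry other cross-cutting behavior. A minimal sketch of one alternative, assuming a tracing header name and a 30-second timeout chosen purely for illustration (neither appears in this commit):

```ts
// Hypothetical variant of logFetch: inject a header and a timeout.
// Pass it as `fetch` in createAgent exactly like logFetch above.
const tracedFetch = async (url: string, options: any) => {
    const headers = new Headers(options?.headers);
    headers.set("x-trace-source", "eliza-agent"); // illustrative header name
    return fetch(url, {
        ...options,
        headers,
        signal: AbortSignal.timeout(30_000), // abort requests stuck past 30s
    });
};
```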
138 changes: 95 additions & 43 deletions packages/core/src/generation.ts
@@ -80,47 +80,68 @@ export async function generateText({

     // allow character.json settings => secrets to override models
     // FIXME: add MODEL_MEDIUM support
-    switch(provider) {
+    switch (provider) {
         // if runtime.getSetting("LLAMACLOUD_MODEL_LARGE") is true and modelProvider is LLAMACLOUD, then use the large model
-        case ModelProviderName.LLAMACLOUD: {
-            switch(modelClass) {
-                case ModelClass.LARGE: {
-                    model = runtime.getSetting("LLAMACLOUD_MODEL_LARGE") || model;
-                }
-                break;
-                case ModelClass.SMALL: {
-                    model = runtime.getSetting("LLAMACLOUD_MODEL_SMALL") || model;
-                }
-                break;
-            }
-        }
-        break;
-        case ModelProviderName.TOGETHER: {
-            switch(modelClass) {
-                case ModelClass.LARGE: {
-                    model = runtime.getSetting("TOGETHER_MODEL_LARGE") || model;
-                }
-                break;
-                case ModelClass.SMALL: {
-                    model = runtime.getSetting("TOGETHER_MODEL_SMALL") || model;
-                }
-                break;
-            }
-        }
-        break;
-        case ModelProviderName.OPENROUTER: {
-            switch(modelClass) {
-                case ModelClass.LARGE: {
-                    model = runtime.getSetting("LARGE_OPENROUTER_MODEL") || model;
-                }
-                break;
-                case ModelClass.SMALL: {
-                    model = runtime.getSetting("SMALL_OPENROUTER_MODEL") || model;
-                }
-                break;
-            }
-        }
-        break;
+        case ModelProviderName.LLAMACLOUD:
+            {
+                switch (modelClass) {
+                    case ModelClass.LARGE:
+                        {
+                            model =
+                                runtime.getSetting("LLAMACLOUD_MODEL_LARGE") ||
+                                model;
+                        }
+                        break;
+                    case ModelClass.SMALL:
+                        {
+                            model =
+                                runtime.getSetting("LLAMACLOUD_MODEL_SMALL") ||
+                                model;
+                        }
+                        break;
+                }
+            }
+            break;
+        case ModelProviderName.TOGETHER:
+            {
+                switch (modelClass) {
+                    case ModelClass.LARGE:
+                        {
+                            model =
+                                runtime.getSetting("TOGETHER_MODEL_LARGE") ||
+                                model;
+                        }
+                        break;
+                    case ModelClass.SMALL:
+                        {
+                            model =
+                                runtime.getSetting("TOGETHER_MODEL_SMALL") ||
+                                model;
+                        }
+                        break;
+                }
+            }
+            break;
+        case ModelProviderName.OPENROUTER:
+            {
+                switch (modelClass) {
+                    case ModelClass.LARGE:
+                        {
+                            model =
+                                runtime.getSetting("LARGE_OPENROUTER_MODEL") ||
+                                model;
+                        }
+                        break;
+                    case ModelClass.SMALL:
+                        {
+                            model =
+                                runtime.getSetting("SMALL_OPENROUTER_MODEL") ||
+                                model;
+                        }
+                        break;
+                }
+            }
+            break;
     }
 
     elizaLogger.info("Selected model:", model);
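The switch above is only reformatted (prettier-style), not changed in behavior: for these three providers it implements a settings-first lookup, where a character-level setting read through runtime.getSetting overrides the default model for the given model class. A condensed restatement of that rule, with the setting keys copied from the hunk (note OpenRouter's reversed key naming):

```ts
// Condensed restatement of the override rule above; a sketch, not the actual code.
const OVERRIDE_KEYS = {
    LLAMACLOUD: { LARGE: "LLAMACLOUD_MODEL_LARGE", SMALL: "LLAMACLOUD_MODEL_SMALL" },
    TOGETHER: { LARGE: "TOGETHER_MODEL_LARGE", SMALL: "TOGETHER_MODEL_SMALL" },
    OPENROUTER: { LARGE: "LARGE_OPENROUTER_MODEL", SMALL: "SMALL_OPENROUTER_MODEL" },
} as const;

function resolveModel(
    getSetting: (key: string) => string | undefined,
    provider: keyof typeof OVERRIDE_KEYS,
    modelClass: "LARGE" | "SMALL",
    defaultModel: string
): string {
    // A configured override wins; otherwise keep the provider default.
    return getSetting(OVERRIDE_KEYS[provider][modelClass]) || defaultModel;
}
```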
@@ -157,7 +178,11 @@ export async function generateText({
         case ModelProviderName.HYPERBOLIC:
         case ModelProviderName.TOGETHER: {
             elizaLogger.debug("Initializing OpenAI model.");
-            const openai = createOpenAI({ apiKey, baseURL: endpoint });
+            const openai = createOpenAI({
+                apiKey,
+                baseURL: endpoint,
+                fetch: runtime.fetch,
+            });
 
             const { text: openaiResponse } = await aiGenerateText({
                 model: openai.languageModel(model),
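The same one-line change repeats for every provider in the hunks that follow: each AI SDK factory used in this file (createOpenAI, createGoogleGenerativeAI, createAnthropic, createGroq, createOllama) accepts an optional fetch function in its settings and, to the best of our reading, falls back to the global fetch when the value is unset, which is what makes threading a possibly-undefined runtime.fetch through safe. A usage sketch, assuming an OPENAI_API_KEY environment variable:

```ts
import { createOpenAI } from "@ai-sdk/openai";

// Stand-in for runtime.fetch; leaving it undefined means every request
// falls back to the global fetch, matching behavior before this commit.
const customFetch: typeof fetch | undefined = undefined;

const openai = createOpenAI({
    apiKey: process.env.OPENAI_API_KEY,
    fetch: customFetch, // all HTTP traffic from this provider goes through it
});
```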
@@ -178,7 +203,9 @@ export async function generateText({
         }
 
         case ModelProviderName.GOOGLE: {
-            const google = createGoogleGenerativeAI();
+            const google = createGoogleGenerativeAI({
+                fetch: runtime.fetch,
+            });
 
             const { text: googleResponse } = await aiGenerateText({
                 model: google(model),
@@ -201,7 +228,10 @@ export async function generateText({
         case ModelProviderName.ANTHROPIC: {
             elizaLogger.debug("Initializing Anthropic model.");
 
-            const anthropic = createAnthropic({ apiKey });
+            const anthropic = createAnthropic({
+                apiKey,
+                fetch: runtime.fetch,
+            });
 
             const { text: anthropicResponse } = await aiGenerateText({
                 model: anthropic.languageModel(model),
@@ -224,7 +254,10 @@ export async function generateText({
         case ModelProviderName.CLAUDE_VERTEX: {
             elizaLogger.debug("Initializing Claude Vertex model.");
 
-            const anthropic = createAnthropic({ apiKey });
+            const anthropic = createAnthropic({
+                apiKey,
+                fetch: runtime.fetch,
+            });
 
             const { text: anthropicResponse } = await aiGenerateText({
                 model: anthropic.languageModel(model),
@@ -248,7 +281,11 @@ export async function generateText({
 
         case ModelProviderName.GROK: {
             elizaLogger.debug("Initializing Grok model.");
-            const grok = createOpenAI({ apiKey, baseURL: endpoint });
+            const grok = createOpenAI({
+                apiKey,
+                baseURL: endpoint,
+                fetch: runtime.fetch,
+            });
 
             const { text: grokResponse } = await aiGenerateText({
                 model: grok.languageModel(model, {
@@ -271,7 +308,7 @@ export async function generateText({
         }
 
         case ModelProviderName.GROQ: {
-            const groq = createGroq({ apiKey });
+            const groq = createGroq({ apiKey, fetch: runtime.fetch });
 
             const { text: groqResponse } = await aiGenerateText({
                 model: groq.languageModel(model),
@@ -318,7 +355,11 @@ export async function generateText({
         case ModelProviderName.REDPILL: {
             elizaLogger.debug("Initializing RedPill model.");
             const serverUrl = models[provider].endpoint;
-            const openai = createOpenAI({ apiKey, baseURL: serverUrl });
+            const openai = createOpenAI({
+                apiKey,
+                baseURL: serverUrl,
+                fetch: runtime.fetch,
+            });
 
             const { text: redpillResponse } = await aiGenerateText({
                 model: openai.languageModel(model),
@@ -341,7 +382,11 @@ export async function generateText({
         case ModelProviderName.OPENROUTER: {
             elizaLogger.debug("Initializing OpenRouter model.");
             const serverUrl = models[provider].endpoint;
-            const openrouter = createOpenAI({ apiKey, baseURL: serverUrl });
+            const openrouter = createOpenAI({
+                apiKey,
+                baseURL: serverUrl,
+                fetch: runtime.fetch,
+            });
 
             const { text: openrouterResponse } = await aiGenerateText({
                 model: openrouter.languageModel(model),
@@ -367,6 +412,7 @@ export async function generateText({
 
             const ollamaProvider = createOllama({
                 baseURL: models[provider].endpoint + "/api",
+                fetch: runtime.fetch,
             });
             const ollama = ollamaProvider(model);
 
@@ -391,6 +437,7 @@ export async function generateText({
             const heurist = createOpenAI({
                 apiKey: apiKey,
                 baseURL: endpoint,
+                fetch: runtime.fetch,
             });
 
             const { text: heuristResponse } = await aiGenerateText({
@@ -436,7 +483,11 @@ export async function generateText({
 
             elizaLogger.debug("Using GAIANET model with baseURL:", baseURL);
 
-            const openai = createOpenAI({ apiKey, baseURL: endpoint });
+            const openai = createOpenAI({
+                apiKey,
+                baseURL: endpoint,
+                fetch: runtime.fetch,
+            });
 
             const { text: openaiResponse } = await aiGenerateText({
                 model: openai.languageModel(model),
@@ -461,6 +512,7 @@ export async function generateText({
             const galadriel = createOpenAI({
                 apiKey: apiKey,
                 baseURL: endpoint,
+                fetch: runtime.fetch,
             });
 
             const { text: galadrielResponse } = await aiGenerateText({
2 changes: 2 additions & 0 deletions packages/core/src/types.ts
@@ -992,6 +992,8 @@ export interface IAgentRuntime {
     evaluators: Evaluator[];
     plugins: Plugin[];
 
+    fetch?: typeof fetch | null;
+
     messageManager: IMemoryManager;
     descriptionManager: IMemoryManager;
     documentsManager: IMemoryManager;
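Because the new member is optional and nullable, code that calls it directly should fall back to the global fetch. A minimal sketch against an abbreviated interface (the real IAgentRuntime has many more members):

```ts
// Abbreviated runtime shape, showing only the member added here.
interface RuntimeLike {
    fetch?: typeof fetch | null;
}

async function runtimeFetch(runtime: RuntimeLike, url: string): Promise<Response> {
    const doFetch = runtime.fetch ?? fetch; // global fetch when none injected
    return doFetch(url);
}
```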
