Add dspy.Embedding
#1735
Conversation
Force-pushed from d33a487 to 7c51351.
litellm.telemetry = False
...
if "LITELLM_LOCAL_MODEL_COST_MAP" not in os.environ:
I think this needs to be done before LiteLLM is imported anywhere in DSPy, for it to have an effect?
I searched their code, and this env var is read at runtime: https://github.com/BerriAI/litellm/blob/5652c375b3e22bab6704e93058c868620c72d6ee/litellm/__init__.py#L309, so our current order should be okay.
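For reference, a minimal sketch of the configuration being discussed; only the two statements shown in the diff are from the PR, while the `"True"` value and the surrounding layout are assumptions:

```python
import os

import litellm

# Disable LiteLLM's telemetry, as in the diff above.
litellm.telemetry = False

# Prefer LiteLLM's bundled (local) model-cost map over fetching it from the
# network, but respect any value the user has already exported. Per the reply
# above, LiteLLM reads this variable at runtime, so setting it after the
# import still takes effect.
if "LITELLM_LOCAL_MODEL_COST_MAP" not in os.environ:
    os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
```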
dspy/clients/embedding.py (Outdated)
kwargs: Additional keyword arguments to pass to the embedding model.

Returns:
    A list of embeddings, one for each input, in the same order as the inputs. Or the output of the custom callable.
Can we ensure the output of this is a numpy tensor or something? Both for litellm and for callables.
sounds good!
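As a rough sketch of that suggestion, the helper and function below normalize both paths to a numpy array; the names and signatures here are hypothetical, not the PR's actual implementation:

```python
import numpy as np
import litellm


def _to_ndarray(embeddings) -> np.ndarray:
    """Coerce a list of embedding vectors (or any array-like) to a float32 numpy array."""
    return np.asarray(embeddings, dtype=np.float32)


def embed(inputs, model="text-embedding-3-small", callable_embedder=None, **kwargs):
    # Custom-callable path: call the user's function, then normalize its output.
    if callable_embedder is not None:
        return _to_ndarray(callable_embedder(inputs, **kwargs))
    # LiteLLM path: pull the vectors out of the response, then normalize.
    response = litellm.embedding(model=model, input=inputs, **kwargs)
    return _to_ndarray([item["embedding"] for item in response.data])
```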
An ideal version of this PR would involve improving the docs at this page: https://dspy-docs.vercel.app/quick-start/getting-started-02/ (see the second cell, or see below).

```python
import functools

import torch
import ujson
from litellm import embedding as Embed

with open("test_collection.jsonl") as f:
    corpus = [ujson.loads(line) for line in f]

index = torch.load("index.pt", weights_only=True)
max_characters = 4000  # >98th percentile of document lengths

@functools.lru_cache(maxsize=None)
def search(query, k=5):
    query_embedding = torch.tensor(Embed(input=query, model="text-embedding-3-small").data[0]["embedding"])
    topk_scores, topk_indices = torch.matmul(index, query_embedding).topk(k)
    topK = [dict(score=score.item(), **corpus[idx]) for idx, score in zip(topk_indices, topk_scores)]
    return [doc["text"][:max_characters] for doc in topK]
```

I'd love to get the same functionality but without that complexity...
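As a rough illustration of that wish, here is a sketch of how the `search` cell might shrink once this PR lands; it reuses `index`, `corpus`, and `max_characters` from the cell above, and the `dspy.Embedding("text-embedding-3-small")` constructor and call signature are assumptions rather than the confirmed API:

```python
import functools

import torch
import dspy

# Assumed usage: construct once with a model name; calling the instance
# returns one embedding vector per input string.
embedder = dspy.Embedding("text-embedding-3-small")

@functools.lru_cache(maxsize=None)
def search(query, k=5):
    # Same top-k dot-product search as above, with the raw litellm plumbing hidden.
    query_embedding = torch.tensor(embedder([query])[0], dtype=index.dtype)
    topk_scores, topk_indices = torch.matmul(index, query_embedding).topk(k)
    topK = [dict(score=score.item(), **corpus[idx]) for idx, score in zip(topk_indices, topk_scores)]
    return [doc["text"][:max_characters] for doc in topK]
```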
Force-pushed from 76f162b to 40bfd8b.
A very simple dspy.Embedding.

dspy.Embedding supports:

- embedding models hosted behind litellm, and
- a custom callable passed to dspy.Embedding, where the output is just the custom callable's output.

Added a unit test for both scenarios.
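For illustration, a hedged usage sketch of the two scenarios described above; the exact constructor and call signatures are assumptions based on this description, not copied from the PR:

```python
import numpy as np
import dspy

# Scenario 1 (assumed): an embedding model hosted behind litellm, referenced by name.
hosted = dspy.Embedding("text-embedding-3-small")
hosted_vectors = hosted(["hello", "world"])

# Scenario 2 (assumed): a custom callable; the output is whatever the callable returns.
def my_embedder(texts):
    return np.random.rand(len(texts), 16)

custom = dspy.Embedding(my_embedder)
custom_vectors = custom(["hello", "world"])
```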
Confirmed that the cache works: