Add Validation for Invalid task_type in PEFT Configurations #2210
Conversation
This reverts commit a86d6f8.
@BenjaminBossan What do you think of this lean implementation?

@d-kleine Thanks for the PR. Right now, I'm at a company offsite and don't have the means to do proper reviews. I'll probably review at the start of next week.
Thanks for the PR. Looking at this implementation, I think we need to adjust the logic a bit. Right now, if the user does not explicitly pass `task_type`, it will be `None`. Since `None` is not in `MODEL_TYPE_TO_PEFT_MODEL_MAPPING`, users may now get an error even though their code might work just fine.

I think the best way to check this is instead: if `task_type` is not in `MODEL_TYPE_TO_PEFT_MODEL_MAPPING` and `task_type` is not `None`, we know that the user provided a value but that value is misspelled. Let's only raise an error in that case.
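The proposed logic can be sketched with a small stand-in (the `TaskType` enum below is a hypothetical two-member stand-in, not PEFT's full enum, and `validate_task_type` is an illustrative helper, not PEFT code):

```python
from enum import Enum

class TaskType(str, Enum):
    # Stand-in for peft.TaskType; the real enum has more members
    CAUSAL_LM = "CAUSAL_LM"
    SEQ_CLS = "SEQ_CLS"

def validate_task_type(task_type):
    # None means "no task type provided" -- allowed, so existing code keeps working
    if task_type is not None and task_type not in list(TaskType):
        raise ValueError(f"Invalid task type: {task_type!r}")

validate_task_type(None)         # fine: user did not pass a task type
validate_task_type("CAUSAL_LM")  # fine: valid value
```

Only a misspelled value such as `"CUASAL_LM"` would raise under this rule.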
I would also suggest checking this in a different place. Right now, this is checked inside of `get_peft_model`. But how about checking it already when the config is created? That way, we can give the feedback as early as possible. I would thus suggest moving the check to `peft.config.PeftConfig.__post_init__`. In addition to that, I think we need to ensure that all configs that inherit, i.e. `LoraConfig`, `PrefixTuningConfig`, etc., all call `super().__post_init__()` in their own `__post_init__` methods. It's a bit annoying to have to change this in each config, but I think it's worth it.

Let me know your thoughts overall.
I see what you mean. I have implemented the changes; they work well so far, returning an error when setting up the config.

About the parent/child issue:

parent:

```python
@dataclass
class PeftConfig(PeftConfigMixin):
    ...  # no `def __post_init__(self):` here
```

child:

```python
@dataclass
class LoraConfig(PeftConfig):
    ...

    def __post_init__(self):
        super().__post_init__()  # error, as it is not defined in the parent
```

So, what to do? Something like this? ⬇️

```python
def __post_init__(self):
    if hasattr(super(), "__post_init__"):
        super().__post_init__()
```
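The `hasattr` guard above does work for plain dataclasses; here is a minimal, self-contained demonstration (class names hypothetical):

```python
from dataclasses import dataclass

@dataclass
class Parent:
    x: int = 0
    # no __post_init__ defined here

@dataclass
class Child(Parent):
    def __post_init__(self):
        # Only delegate if some ancestor actually defines __post_init__;
        # without this guard, super().__post_init__() raises AttributeError
        if hasattr(super(), "__post_init__"):
            super().__post_init__()

child = Child()  # no AttributeError, even though Parent lacks __post_init__
```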
Hmm, I think this check for the presence of `__post_init__` can be avoided. Instead:

```python
# peft/config.py, in PeftConfigMixin
def __post_init__(self):
    if (self.task_type is not None) and (self.task_type not in list(TaskType)):
        raise ValueError(
            f"Invalid task type: '{self.task_type}'. "
            f"Must be one of the following task types: {', '.join(TaskType)}."
        )
```

```python
# peft/tuners/lora/config.py
def __post_init__(self):
    super().__post_init__()
    ...  # rest of the code
```

Then running:

```python
from peft import LoraConfig

lora_config_1 = LoraConfig(task_type=None)
lora_config_1 = LoraConfig(task_type="CAUSAL_LM")
lora_config_2 = LoraConfig(task_type="CAUSAL_LMX")
```

I get the expected error on the last config.
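One detail worth noting in that error message: `', '.join(TaskType)` only works because `TaskType` subclasses `str`, so `join` consumes the members directly as strings. A stand-in demo (two hypothetical members, not PEFT's full enum):

```python
from enum import Enum

class TaskType(str, Enum):  # str-backed, like peft.TaskType
    CAUSAL_LM = "CAUSAL_LM"
    SEQ_CLS = "SEQ_CLS"

# Members are str instances, so join() accepts them without explicit str() calls
print(", ".join(TaskType))  # CAUSAL_LM, SEQ_CLS
```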
I was trying to say that there is almost no …

Yes, exactly, sorry for the confusion.
Alright, thank you! Thanks to your sample code, I now understand. I have implemented the logic you suggested and pushed the changes.

I have also noticed that you can import a `TaskType` directly:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, TaskType

# Step 1: Load the base model and tokenizer
model_name_or_path = "meta-llama/Llama-3.2-1B"
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
model = AutoModelForCausalLM.from_pretrained(model_name_or_path)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

# Step 2: Define the LoRA configuration
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.1,
    target_modules=["q_proj", "v_proj"],
    bias="none",
    task_type=TaskType.BLABLA,  # Invalid task type
)
```

This would raise an `AttributeError`:

```
AttributeError: BLABLA
```

Even though it's not as specific as above, I think this is fine. What do you think?
Thanks for the changes, overall this looks good.

Before merging, I would really like to see a unit test added for this new check. For a start, let's take this test (lines 88 to 90 in 8874ab5):

```python
@pytest.mark.parametrize("config_class", ALL_CONFIG_CLASSES)
def test_task_type(self, config_class):
    config_class(task_type="test")
```

It should now raise an error if I'm not mistaken. So let's pass a valid task type here. Next, let's create a similar test with an invalid task type and check that the error you added is raised.
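A self-contained sketch of that pair of tests (here `DummyConfig` is a stand-in for the real config classes, and the exact error message is an assumption):

```python
from dataclasses import dataclass
from typing import Optional

import pytest

VALID_TASK_TYPES = {"CAUSAL_LM", "SEQ_CLS"}  # stand-in for peft.TaskType

@dataclass
class DummyConfig:
    # Stand-in for LoraConfig, PrefixTuningConfig, etc.
    task_type: Optional[str] = None

    def __post_init__(self):
        if self.task_type is not None and self.task_type not in VALID_TASK_TYPES:
            raise ValueError(f"Invalid task type: '{self.task_type}'")

ALL_CONFIG_CLASSES = [DummyConfig]

@pytest.mark.parametrize("config_class", ALL_CONFIG_CLASSES)
def test_task_type_valid(config_class):
    config_class(task_type="CAUSAL_LM")  # valid value: must not raise

@pytest.mark.parametrize("config_class", ALL_CONFIG_CLASSES)
def test_task_type_invalid_raises(config_class):
    with pytest.raises(ValueError, match="Invalid task type"):
        config_class(task_type="test")
```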
I have not worked without defining a `task_type` yet, so are you sure it is a good idea not to return an error (or at least a warning) in this case?
The task types are very specific to the transformers definition of these tasks. However, PEFT is a general framework that is also applicable to non-transformers models (not everything but most). Therefore, it's totally fine not to supply a task type.
> I have also noticed that you can import a `TaskType` directly

Yes, that's the nice thing about using `enum`s: you will immediately know if you made a typo, and your text editor will help you with auto-complete. I'd say using this is the "proper" way of doing it, but it requires a bit more typing, so users are often lazy and just enter the string (I'm guilty as well).
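The difference is easy to demonstrate with any `str`-backed enum (a stand-in shaped like `peft.TaskType`, not the real one):

```python
from enum import Enum

class TaskType(str, Enum):  # stand-in shaped like peft.TaskType
    CAUSAL_LM = "CAUSAL_LM"
    SEQ_CLS = "SEQ_CLS"

# Enum access: a typo fails immediately at lookup time
try:
    TaskType.CUASAL_LM
except AttributeError as exc:
    print(f"caught: {exc}")

# Raw string: the typo only surfaces if the config validates it
print("CUASAL_LM" in list(TaskType))  # False -- would slip through without a check
```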
Thanks a lot for adding the tests, they're almost perfect, I just have a small suggestion to simplify them.
Yeah, I agree. I am using VS Code, and the direct `TaskType` import is indeed better.
Great that the enum works as intended. Still, having this extra check for strings will be helpful overall.
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Yeah, I fully agree 🙂

@d-kleine Thanks for the update, please run …

Done
@d-kleine A test is now failing which was indirectly caused by the changes of this PR. Long story short, please go to this line (peft/tests/test_initialization.py, line 1231 in d9aa089) and change it to …
Fixed and pushed. I was looking into the error too (it seems that the quantized state and its adapter initialization must be aligned for …)
Thanks a lot for adding the task type validation, the PR LGTM.
Thanks for the merge!
Fixes #2203

This PR adds validation for the `task_type` parameter across all PEFT types (e.g., LoRA, Prefix Tuning, etc.) to ensure that only valid task types, as defined in the `TaskType` enum, are used. This resolves the issue where invalid or nonsense task types (e.g., `"BLABLA"`) or typos (e.g., `"CUASAL_LM"` instead of `"CAUSAL_LM"`) could be passed without raising an error, potentially causing unexpected behavior.

Why This Change Is Necessary:

- Validates `task_type`, maintaining consistency across different PEFT configurations.

Tests:

Example Scenario:

Before this PR, the following code would not raise an error despite using an invalid task type. With this PR, attempting to use an invalid task type like `"BLABLA"` will now raise a `ValueError`, ensuring that only valid task types are accepted.

For multiple PEFT methods:
INVALID `task_type`: (screenshot)

VALID `task_type`: (screenshot)