Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Validation for Invalid task_type in PEFT Configurations #2210

Merged
merged 11 commits into from
Nov 21, 2024

Conversation

d-kleine
Copy link
Contributor

@d-kleine d-kleine commented Nov 11, 2024

Fixes #2203

This PR adds validation for the task_type parameter across all PEFT types (e.g., LoRA, Prefix Tuning, etc.) to ensure that only valid task types, as defined in the TaskType enum, are used. This resolves the issue where invalid or non-sense task types (e.g., "BLABLA") or typos (e.g., "CUASAL_LM" instead of "CAUSAL_LM") could be passed without raising an error, potentially causing unexpected behavior.

Why This Change Is Necessary:

  • Prevents Silent Failures: Without this validation, users could pass incorrect or unsupported task types without receiving any feedback. This could lead to silent failures or unexpected behavior during training or inference.
  • Consistency: Ensures that all PEFT types adhere to the same validation rules for task_type, maintaining consistency across different PEFT configurations.

Tests:

  • Ensured that invalid task types raise appropriate errors.
  • Verified that valid task types continue to work as expected.

Example Scenario:

Before this PR, the following code would not raise an error despite using an invalid task type:

from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, get_peft_model

# Step 1: Load the base model and tokenizer
model_name_or_path = "meta-llama/Llama-3.2-1B"
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
model = AutoModelForCausalLM.from_pretrained(model_name_or_path)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

# Step 2: Define the LoRA configuration
lora_config = LoraConfig(
    r=16,  
    lora_alpha=32, 
    lora_dropout=0.1,  
    target_modules=["q_proj", "v_proj"],  
    bias="none",  
    task_type="BLABLA"  # Invalid task type 
)

# Step 3: Wrap the base model with LoRA using the configuration
peft_model = get_peft_model(model, lora_config)

With this PR, attempting to use an invalid task type like "BLABLA" will now raise a ValueError, ensuring that only valid task types are accepted:

ValueError: Invalid task type: 'BLABLA'. Must be one of the following task types: ['SEQ_CLS', 'SEQ_2_SEQ_LM', 'CAUSAL_LM', 'TOKEN_CLS', 'QUESTION_ANS', 'FEATURE_EXTRACTION'].

For multiple PEFT methods:

INVALID task_type

from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import (
    LoraConfig, 
    PromptTuningConfig, 
    PrefixTuningConfig,
    get_peft_model
)

# Step 1: Load the base model and tokenizer
model_name_or_path = "meta-llama/Llama-3.2-1B"
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
model = AutoModelForCausalLM.from_pretrained(model_name_or_path)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

# Step 2: Define configurations for each PEFT method with an invalid task type
configs = [
    LoraConfig(
        r=16,
        lora_alpha=32,
        lora_dropout=0.1,
        target_modules=["q_proj", "v_proj"],
        bias="none",
        task_type="BLABLA"  # Invalid task type
    ),
    PromptTuningConfig(
        num_virtual_tokens=20,
        task_type="BLABLA"  # Invalid task type
    ),
    PrefixTuningConfig(
        num_virtual_tokens=30,
        task_type="BLABLA"  # Invalid task type
    ),
    # Add more configurations for other PeftTypes if necessary...
]

# Step 3: Test each configuration and check if it raises a ValueError for invalid task type
for config in configs:
    try:
        print(f"Testing {config.__class__.__name__}...")
        peft_model = get_peft_model(model, config)
        print("OK")
    except ValueError as e:
        print(f"Error: {e}")
Testing LoraConfig...
Error: Invalid task type: 'BLABLA'. Must be one of the following task types: ['SEQ_CLS', 'SEQ_2_SEQ_LM', 'CAUSAL_LM', 'TOKEN_CLS', 'QUESTION_ANS', 'FEATURE_EXTRACTION'].
Testing PromptTuningConfig...
Error: Invalid task type: 'BLABLA'. Must be one of the following task types: ['SEQ_CLS', 'SEQ_2_SEQ_LM', 'CAUSAL_LM', 'TOKEN_CLS', 'QUESTION_ANS', 'FEATURE_EXTRACTION'].
Testing PrefixTuningConfig...
Error: Invalid task type: 'BLABLA'. Must be one of the following task types: ['SEQ_CLS', 'SEQ_2_SEQ_LM', 'CAUSAL_LM', 'TOKEN_CLS', 'QUESTION_ANS', 'FEATURE_EXTRACTION'].

VALID task_type

from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import (
    LoraConfig, 
    PromptTuningConfig, 
    PrefixTuningConfig,
    get_peft_model
)

# Step 1: Load the base model and tokenizer
model_name_or_path = "meta-llama/Llama-3.2-1B"
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
model = AutoModelForCausalLM.from_pretrained(model_name_or_path)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

# Step 2: Define configurations for each PEFT method with an valid task type
configs = [
    LoraConfig(
        r=16,
        lora_alpha=32,
        lora_dropout=0.1,
        target_modules=["q_proj", "v_proj"],
        bias="none",
        task_type="CAUSAL_LM"  # Valid task type
    ),
    PromptTuningConfig(
        num_virtual_tokens=20,
        task_type="CAUSAL_LM"  # Valid task type
    ),
    PrefixTuningConfig(
        num_virtual_tokens=30,
        task_type="CAUSAL_LM"  # Valid task type
    ),
]

# Step 3: Test each configuration and check if it raises a ValueError for invalid task type
for config in configs:
    try:
        print(f"Testing {config.__class__.__name__}...")
        peft_model = get_peft_model(model, config)
        print("OK")
    except ValueError as e:
        print(f"Error: {e}")
Testing LoraConfig...
OK
Testing PromptTuningConfig...
OK
Testing PrefixTuningConfig...
OK

@d-kleine d-kleine marked this pull request as ready for review November 11, 2024 09:45
@d-kleine
Copy link
Contributor Author

@BenjaminBossan What do you think of this lean implementation?

@BenjaminBossan
Copy link
Member

@d-kleine Thanks for the PR. Right now, I'm at a company offsite and don't have the means to do proper reviews. I'll probably review at the start of next week.

Copy link
Member

@BenjaminBossan BenjaminBossan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the PR. Looking at this implementation, I think we need to adjust the logic a bit. Right now, if the user does not explicitly pass task_type, it will be None. Since None is not in MODEL_TYPE_TO_PEFT_MODEL_MAPPING, users may now get an error even though their code might work just fine.

I think the best way to check this is instead: If task_type is not in MODEL_TYPE_TO_PEFT_MODEL_MAPPING and if task_type is not None, we know that the user provided a value but that value is misspelled. Let's only raise an error in that case.

I would also suggest to check this in a different place. Right now, this is checked inside of get_peft_model. But how about checking it already when the config is created? That way, we can give the feedback as early as possible. I would thus suggest to move the check to peft.config.PeftConfig.__post_init__. In addition to that, I think we need to ensure that all configs that inherit, i.e. LoraConfig, PrefixTuningConfig, etc. all call super().__post_init__() in their own __post_init__ methods. It's a bit annoying to have to change this in each config, but I think it's worth it.

Let me know your thoughts overall.

@d-kleine
Copy link
Contributor Author

d-kleine commented Nov 18, 2024

I see what you mean. I have implemented the changes, they work good so far, returning an error when setting up the config with either no task_type is passed (None) or the one passed does not match with the task types defined in the mappings.

About the super().__post_init__(), I agree this is also a good idea for this project in general. But I have noticed most parents (e.g. PeftConfig, PromptLearningConfig) don't have a __post_init__() method yet, thus calling super().__post_init__() in the def __post_init__(self): in would result in an error message then

parent:

@dataclass
class PeftConfig(PeftConfigMixin):
    ... # no `def __post_init__(self):` here

child:

@dataclass
class LoraConfig(PeftConfig):
...
    def __post_init__(self):
        super().__post_init__() # error as not defined in parent

So, what to do? Something like this? ⬇️

def __post_init__(self):
    if hasattr(super(), '__post_init__'):
        super().__post_init__()

@BenjaminBossan
Copy link
Member

Hmm, I think this check for the presence of __post_init__ should not be necessary. I tried this locally:

# peft/config.py in PeftConfigMixin
    def __post_init__(self):
        if (self.task_type is not None) and (self.task_type not in list(TaskType)):
            raise ValueError(f"Invalid task type: '{self.task_type}'. Must be one of the following task types: {', '.join(TaskType)}.")

# peft/tuners/lora/config.py
    def __post_init__(self):
        super().__post_init__()

        ...  # rest of the code

Then running:

from peft import LoraConfig

lora_config_1 = LoraConfig(task_type=None)
lora_config_1 = LoraConfig(task_type="CAUSAL_LM")
lora_config_2 = LoraConfig(task_type="CAUSAL_LMX")

I get the expected error on the last config:

ValueError: Invalid task type: 'CAUSAL_LMX'. Must be one of the following task types: SEQ_CLS, SEQ_2_SEQ_LM, CAUSAL_LM, TOKEN_CLS, QUESTION_ANS, FEATURE_EXTRACTION.

@d-kleine
Copy link
Contributor Author

d-kleine commented Nov 19, 2024

I was trying to say that there is almost no __post_init__ in the parents yet, but now I understood that you want me to create a __post_init__ containing the task type check in PeftConfigMixin (which are all PEFT configs seem to be based on). Is that right?

@BenjaminBossan
Copy link
Member

you want me to create a __post_init__ containing the task type check in PeftConfigMixin (which are all PEFT configs seem to be based on). Is that right?

Yes, exactly, sorry for the confusion.

@d-kleine
Copy link
Contributor Author

d-kleine commented Nov 19, 2024

Alright, thank you!

Thanks to your sample code, I now understand that task_type can be None as a valid value. Initially, I thought setting task_type to None would be an invalid configuration and therefore should also raise a ValueError. I have not worked without defining a task_type yet, so are you sure this is a good idea not returning an error (or at least a warning) in this case?

I just have implemented the logic you have suggested and pushed the changes.


I have also noticed that you can import a TaskType directly, like

from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import TaskType

# Step 1: Load the base model and tokenizer
model_name_or_path = "meta-llama/Llama-3.2-1B"
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
model = AutoModelForCausalLM.from_pretrained(model_name_or_path)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

# Step 2: Define the LoRA configuration
lora_config = LoraConfig(
    r=16,  
    lora_alpha=32, 
    lora_dropout=0.1,  
    target_modules=["q_proj", "v_proj"],  
    bias="none",  
    task_type=TaskType.BLABLA  # Invalid task type 
)

This would raise an AttributeError:

AttributeError: BLABLA

Even though it's not as specific as above, I think this is fine. What do you think?

Copy link
Member

@BenjaminBossan BenjaminBossan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the changes, overall this looks good.

Before merging, I would really like to see a unit test added for this new check. For a start, let's take this test:

@pytest.mark.parametrize("config_class", ALL_CONFIG_CLASSES)
def test_task_type(self, config_class):
config_class(task_type="test")

It should now raise an error if I'm not mistaken. So let's pass a valid task type here. Next, let's create a similar test with an invalid task type and check that the error you added is raised.

I have not worked without defining a task_type yet, so are you sure this is a good idea not returning an error (or at least a warning) in this case?

The task types are very specific to the transformers definition of these tasks. However, PEFT is a general framework that is also applicable to non-transformers models (not everything but most). Therefore, it's totally fine not to supply a task type.

I have also noticed that you can important a TaskType directly

Yes, that's the nice thing about using enums, you will immediately know if you made a typo and also your text editor will help you with auto-complete. I'd say using this is the "proper" way of doing it but it requires a bit more typing so users are often lazy and just enter the string (I'm guilty as well).

@d-kleine
Copy link
Contributor Author

d-kleine commented Nov 21, 2024

It should now raise an error if I'm not mistaken. So let's pass a valid task type here. Next, let's create a similar test with an invalid task type and check that the error you added is raised.

I have added the tests (one for valid task types, one for invalid task types), as requested. I have defined the valid task types as the ones defined in TaskType plus None. For the invalid task types, I have only used the provided example "test".

Furthermore, I have run the tests before (you were right, the tests would have failed). After implementing the changes in the tests, the unit tests will pass then:

peft\tests\test_config.py::TestPeftConfig::test_invalid_task_type[test-AdaLoraConfig] PASSED                                                                                                                                                                     [  0%] 
peft\tests\test_config.py::TestPeftConfig::test_invalid_task_type[test-AdaptionPromptConfig] PASSED                                                                                                                                                              [  1%] 
peft\tests\test_config.py::TestPeftConfig::test_invalid_task_type[test-BOFTConfig] PASSED                                                                                                                                                                        [  2%] 
peft\tests\test_config.py::TestPeftConfig::test_invalid_task_type[test-FourierFTConfig] PASSED                                                                                                                                                                   [  2%]
peft\tests\test_config.py::TestPeftConfig::test_invalid_task_type[test-HRAConfig] PASSED                                                                                                                                                                         [  3%] 
peft\tests\test_config.py::TestPeftConfig::test_invalid_task_type[test-IA3Config] PASSED                                                                                                                                                                         [  4%] 
peft\tests\test_config.py::TestPeftConfig::test_invalid_task_type[test-LNTuningConfig] PASSED                                                                                                                                                                    [  5%] 
peft\tests\test_config.py::TestPeftConfig::test_invalid_task_type[test-LoHaConfig] PASSED                                                                                                                                                                        [  5%] 
peft\tests\test_config.py::TestPeftConfig::test_invalid_task_type[test-LoKrConfig] PASSED                                                                                                                                                                        [  6%] 
peft\tests\test_config.py::TestPeftConfig::test_invalid_task_type[test-LoraConfig] PASSED                                                                                                                                                                        [  7%] 
peft\tests\test_config.py::TestPeftConfig::test_invalid_task_type[test-MultitaskPromptTuningConfig] PASSED                                                                                                                                                       [  8%]
peft\tests\test_config.py::TestPeftConfig::test_invalid_task_type[test-PolyConfig] PASSED                                                                                                                                                                        [  8%] 
peft\tests\test_config.py::TestPeftConfig::test_invalid_task_type[test-PrefixTuningConfig] PASSED                                                                                                                                                                [  9%] 
peft\tests\test_config.py::TestPeftConfig::test_invalid_task_type[test-PromptEncoderConfig] PASSED                                                                                                                                                               [ 10%] 
peft\tests\test_config.py::TestPeftConfig::test_invalid_task_type[test-PromptTuningConfig] PASSED                                                                                                                                                                [ 11%] 
peft\tests\test_config.py::TestPeftConfig::test_invalid_task_type[test-VeraConfig] PASSED                                                                                                                                                                        [ 11%] 
peft\tests\test_config.py::TestPeftConfig::test_invalid_task_type[test-VBLoRAConfig] PASSED                                                                                                                                                                      [ 12%] 
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[SEQ_CLS-AdaLoraConfig] PASSED                                                                                                                                                                    [ 13%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[SEQ_CLS-AdaptionPromptConfig] PASSED                                                                                                                                                             [ 13%] 
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[SEQ_CLS-BOFTConfig] PASSED                                                                                                                                                                       [ 14%] 
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[SEQ_CLS-FourierFTConfig] PASSED                                                                                                                                                                  [ 15%] 
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[SEQ_CLS-HRAConfig] PASSED                                                                                                                                                                        [ 16%] 
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[SEQ_CLS-IA3Config] PASSED                                                                                                                                                                        [ 16%] 
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[SEQ_CLS-LNTuningConfig] PASSED                                                                                                                                                                   [ 17%] 
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[SEQ_CLS-LoHaConfig] PASSED                                                                                                                                                                       [ 18%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[SEQ_CLS-LoKrConfig] PASSED                                                                                                                                                                       [ 19%] 
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[SEQ_CLS-LoraConfig] PASSED                                                                                                                                                                       [ 19%] 
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[SEQ_CLS-MultitaskPromptTuningConfig] PASSED                                                                                                                                                      [ 20%] 
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[SEQ_CLS-PolyConfig] PASSED                                                                                                                                                                       [ 21%] 
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[SEQ_CLS-PrefixTuningConfig] PASSED                                                                                                                                                               [ 22%] 
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[SEQ_CLS-PromptEncoderConfig] PASSED                                                                                                                                                              [ 22%] 
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[SEQ_CLS-PromptTuningConfig] PASSED                                                                                                                                                               [ 23%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[SEQ_CLS-VeraConfig] PASSED                                                                                                                                                                       [ 24%] 
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[SEQ_CLS-VBLoRAConfig] PASSED                                                                                                                                                                     [ 25%] 
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[SEQ_2_SEQ_LM-AdaLoraConfig] PASSED                                                                                                                                                               [ 25%] 
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[SEQ_2_SEQ_LM-AdaptionPromptConfig] PASSED                                                                                                                                                        [ 26%] 
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[SEQ_2_SEQ_LM-BOFTConfig] PASSED                                                                                                                                                                  [ 27%] 
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[SEQ_2_SEQ_LM-FourierFTConfig] PASSED                                                                                                                                                             [ 27%] 
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[SEQ_2_SEQ_LM-HRAConfig] PASSED                                                                                                                                                                   [ 28%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[SEQ_2_SEQ_LM-IA3Config] PASSED                                                                                                                                                                   [ 29%] 
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[SEQ_2_SEQ_LM-LNTuningConfig] PASSED                                                                                                                                                              [ 30%] 
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[SEQ_2_SEQ_LM-LoHaConfig] PASSED                                                                                                                                                                  [ 30%] 
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[SEQ_2_SEQ_LM-LoKrConfig] PASSED                                                                                                                                                                  [ 31%] 
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[SEQ_2_SEQ_LM-LoraConfig] PASSED                                                                                                                                                                  [ 32%] 
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[SEQ_2_SEQ_LM-MultitaskPromptTuningConfig] PASSED                                                                                                                                                 [ 33%] 
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[SEQ_2_SEQ_LM-PolyConfig] PASSED                                                                                                                                                                  [ 33%] 
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[SEQ_2_SEQ_LM-PrefixTuningConfig] PASSED                                                                                                                                                          [ 34%] 
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[SEQ_2_SEQ_LM-PromptEncoderConfig] PASSED                                                                                                                                                         [ 35%] 
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[SEQ_2_SEQ_LM-PromptTuningConfig] PASSED                                                                                                                                                          [ 36%] 
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[SEQ_2_SEQ_LM-VeraConfig] PASSED                                                                                                                                                                  [ 36%] 
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[SEQ_2_SEQ_LM-VBLoRAConfig] PASSED                                                                                                                                                                [ 37%] 
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[CAUSAL_LM-AdaLoraConfig] PASSED                                                                                                                                                                  [ 38%] 
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[CAUSAL_LM-AdaptionPromptConfig] PASSED                                                                                                                                                           [ 38%] 
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[CAUSAL_LM-BOFTConfig] PASSED                                                                                                                                                                     [ 39%] 
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[CAUSAL_LM-FourierFTConfig] PASSED                                                                                                                                                                [ 40%] 
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[CAUSAL_LM-HRAConfig] PASSED                                                                                                                                                                      [ 41%] 
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[CAUSAL_LM-IA3Config] PASSED                                                                                                                                                                      [ 41%] 
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[CAUSAL_LM-LNTuningConfig] PASSED                                                                                                                                                                 [ 42%] 
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[CAUSAL_LM-LoHaConfig] PASSED                                                                                                                                                                     [ 43%] 
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[CAUSAL_LM-LoKrConfig] PASSED                                                                                                                                                                     [ 44%] 
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[CAUSAL_LM-LoraConfig] PASSED                                                                                                                                                                     [ 44%] 
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[CAUSAL_LM-MultitaskPromptTuningConfig] PASSED                                                                                                                                                    [ 45%] 
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[CAUSAL_LM-PolyConfig] PASSED                                                                                                                                                                     [ 46%] 
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[CAUSAL_LM-PrefixTuningConfig] PASSED                                                                                                                                                             [ 47%] 
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[CAUSAL_LM-PromptEncoderConfig] PASSED                                                                                                                                                            [ 47%] 
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[CAUSAL_LM-PromptTuningConfig] PASSED                                                                                                                                                             [ 48%] 
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[CAUSAL_LM-VeraConfig] PASSED                                                                                                                                                                     [ 49%] 
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[CAUSAL_LM-VBLoRAConfig] PASSED                                                                                                                                                                   [ 50%] 
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[TOKEN_CLS-AdaLoraConfig] PASSED                                                                                                                                                                  [ 50%] 
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[TOKEN_CLS-AdaptionPromptConfig] PASSED                                                                                                                                                           [ 51%] 
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[TOKEN_CLS-BOFTConfig] PASSED                                                                                                                                                                     [ 52%] 
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[TOKEN_CLS-FourierFTConfig] PASSED                                                                                                                                                                [ 52%] 
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[TOKEN_CLS-HRAConfig] PASSED                                                                                                                                                                      [ 53%] 
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[TOKEN_CLS-IA3Config] PASSED                                                                                                                                                                      [ 54%] 
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[TOKEN_CLS-LNTuningConfig] PASSED                                                                                                                                                                 [ 55%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[TOKEN_CLS-LoHaConfig] PASSED                                                                                                                                                                     [ 55%] 
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[TOKEN_CLS-LoKrConfig] PASSED                                                                                                                                                                     [ 56%] 
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[TOKEN_CLS-LoraConfig] PASSED                                                                                                                                                                     [ 57%] 
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[TOKEN_CLS-MultitaskPromptTuningConfig] PASSED                                                                                                                                                    [ 58%] 
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[TOKEN_CLS-PolyConfig] PASSED                                                                                                                                                                     [ 58%] 
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[TOKEN_CLS-PrefixTuningConfig] PASSED                                                                                                                                                             [ 59%] 
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[TOKEN_CLS-PromptEncoderConfig] PASSED                                                                                                                                                            [ 60%] 
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[TOKEN_CLS-PromptTuningConfig] PASSED                                                                                                                                                             [ 61%] 
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[TOKEN_CLS-VeraConfig] PASSED                                                                                                                                                                     [ 61%] 
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[TOKEN_CLS-VBLoRAConfig] PASSED                                                                                                                                                                   [ 62%] 
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[QUESTION_ANS-AdaLoraConfig] PASSED                                                                                                                                                               [ 63%] 
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[QUESTION_ANS-AdaptionPromptConfig] PASSED                                                                                                                                                        [ 63%] 
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[QUESTION_ANS-BOFTConfig] PASSED                                                                                                                                                                  [ 64%] 
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[QUESTION_ANS-FourierFTConfig] PASSED                                                                                                                                                             [ 65%] 
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[QUESTION_ANS-HRAConfig] PASSED                                                                                                                                                                   [ 66%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[QUESTION_ANS-IA3Config] PASSED                                                                                                                                                                   [ 66%] 
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[QUESTION_ANS-LNTuningConfig] PASSED                                                                                                                                                              [ 67%] 
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[QUESTION_ANS-LoHaConfig] PASSED                                                                                                                                                                  [ 68%] 
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[QUESTION_ANS-LoKrConfig] PASSED                                                                                                                                                                  [ 69%] 
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[QUESTION_ANS-LoraConfig] PASSED                                                                                                                                                                  [ 69%] 
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[QUESTION_ANS-MultitaskPromptTuningConfig] PASSED                                                                                                                                                 [ 70%] 
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[QUESTION_ANS-PolyConfig] PASSED                                                                                                                                                                  [ 71%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[QUESTION_ANS-PrefixTuningConfig] PASSED                                                                                                                                                          [ 72%] 
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[QUESTION_ANS-PromptEncoderConfig] PASSED                                                                                                                                                         [ 72%] 
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[QUESTION_ANS-PromptTuningConfig] PASSED                                                                                                                                                          [ 73%] 
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[QUESTION_ANS-VeraConfig] PASSED                                                                                                                                                                  [ 74%] 
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[QUESTION_ANS-VBLoRAConfig] PASSED                                                                                                                                                                [ 75%] 
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[FEATURE_EXTRACTION-AdaLoraConfig] PASSED                                                                                                                                                         [ 75%] 
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[FEATURE_EXTRACTION-AdaptionPromptConfig] PASSED                                                                                                                                                  [ 76%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[FEATURE_EXTRACTION-BOFTConfig] PASSED                                                                                                                                                            [ 77%] 
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[FEATURE_EXTRACTION-FourierFTConfig] PASSED                                                                                                                                                       [ 77%] 
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[FEATURE_EXTRACTION-HRAConfig] PASSED                                                                                                                                                             [ 78%] 
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[FEATURE_EXTRACTION-IA3Config] PASSED                                                                                                                                                             [ 79%] 
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[FEATURE_EXTRACTION-LNTuningConfig] PASSED                                                                                                                                                        [ 80%] 
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[FEATURE_EXTRACTION-LoHaConfig] PASSED                                                                                                                                                            [ 80%] 
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[FEATURE_EXTRACTION-LoKrConfig] PASSED                                                                                                                                                            [ 81%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[FEATURE_EXTRACTION-LoraConfig] PASSED                                                                                                                                                            [ 82%] 
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[FEATURE_EXTRACTION-MultitaskPromptTuningConfig] PASSED                                                                                                                                           [ 83%] 
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[FEATURE_EXTRACTION-PolyConfig] PASSED                                                                                                                                                            [ 83%] 
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[FEATURE_EXTRACTION-PrefixTuningConfig] PASSED                                                                                                                                                    [ 84%] 
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[FEATURE_EXTRACTION-PromptEncoderConfig] PASSED                                                                                                                                                   [ 85%] 
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[FEATURE_EXTRACTION-PromptTuningConfig] PASSED                                                                                                                                                    [ 86%] 
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[FEATURE_EXTRACTION-VeraConfig] PASSED                                                                                                                                                            [ 86%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[FEATURE_EXTRACTION-VBLoRAConfig] PASSED                                                                                                                                                          [ 87%] 
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[None-AdaLoraConfig] PASSED                                                                                                                                                                       [ 88%] 
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[None-AdaptionPromptConfig] PASSED                                                                                                                                                                [ 88%] 
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[None-BOFTConfig] PASSED                                                                                                                                                                          [ 89%] 
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[None-FourierFTConfig] PASSED                                                                                                                                                                     [ 90%] 
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[None-HRAConfig] PASSED                                                                                                                                                                           [ 91%] 
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[None-IA3Config] PASSED                                                                                                                                                                           [ 91%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[None-LNTuningConfig] PASSED                                                                                                                                                                      [ 92%] 
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[None-LoHaConfig] PASSED                                                                                                                                                                          [ 93%] 
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[None-LoKrConfig] PASSED                                                                                                                                                                          [ 94%] 
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[None-LoraConfig] PASSED                                                                                                                                                                          [ 94%] 
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[None-MultitaskPromptTuningConfig] PASSED                                                                                                                                                         [ 95%] 
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[None-PolyConfig] PASSED                                                                                                                                                                          [ 96%] 
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[None-PrefixTuningConfig] PASSED                                                                                                                                                                  [ 97%] 
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[None-PromptEncoderConfig] PASSED                                                                                                                                                                 [ 97%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[None-PromptTuningConfig] PASSED                                                                                                                                                                  [ 98%] 
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[None-VeraConfig] PASSED                                                                                                                                                                          [ 99%] 
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[None-VBLoRAConfig] PASSED                                                                                                                                                                        [100%]

All tests in the script will also pass:
grafik

If you need further changes / improvements, please let me know.

Yes, that's the nice thing about using enums, you will immediately know if you made a typo and also your text editor will help you with auto-complete. I'd say using this is the "proper" way of doing it but it requires a bit more typing so users are often lazy and just enter the string (I'm guilty as well).

Yeah, I agree. I am using VS Code and the direct TaskType import is indeed better:

grafik

Copy link
Member

@BenjaminBossan BenjaminBossan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks a lot for adding the tests, they're almost perfect, I just have a small suggestion to simplify them.

Yeah, I agree. I am using VS Code and the direct TaskType import is indeed better:

Great that the enum works as intended. Still, having this extra check for strings will be helpful overall.

tests/test_config.py Outdated Show resolved Hide resolved
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@d-kleine
Copy link
Contributor Author

Great that the enum works as intended. Still, having this extra check for strings will be helpful overall.

Yeah, I fully agree 🙂

@BenjaminBossan
Copy link
Member

@d-kleine Thanks for the update, please run make style too.

@d-kleine
Copy link
Contributor Author

@d-kleine Thanks for the update, please run make style too.

Done

@BenjaminBossan
Copy link
Member

@d-kleine There a test is now failing which was indirectly caused by the changes of this PR. Long story short, please go to this line:

AdaLoraConfig(loftq_config={"loftq": "config"})

and change it to AdaLoraConfig(init_lora_weights="loftq", loftq_config={"loftq": "config"}). LMK if you want me to explain why it's needed, but it's really just tangential to your PR.

@d-kleine
Copy link
Contributor Author

Fixed and pushed. I was looking into the error too (it seems that the quantized state and its adapter initialization must be aligned for AdaLoraConfig). But did this happen because of my PR?

Copy link
Member

@BenjaminBossan BenjaminBossan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks a lot for adding the task type validation, the PR LGTM.

@BenjaminBossan BenjaminBossan merged commit 0443734 into huggingface:main Nov 21, 2024
14 checks passed
@d-kleine
Copy link
Contributor Author

d-kleine commented Nov 21, 2024

Thanks for the merge!

@d-kleine d-kleine deleted the task_type_assertion branch November 30, 2024 14:24
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Add Assertions for task_type in LoraConfig
3 participants