85 lines
2.7 KiB
Python
Raw Normal View History

2024-01-22 23:46:27 +08:00
import inspect
from fastapi import Form, Query
2024-01-22 23:46:27 +08:00
from pydantic import BaseModel
from pydantic.fields import FieldInfo
from typing import Type, TypeVar
2024-01-22 23:46:27 +08:00
# Type variable bound to pydantic's BaseModel so the decorators below can be
# annotated as returning the same model class they were given.
BaseModelVar = TypeVar('BaseModelVar', bound=BaseModel)
def as_query(cls: Type[BaseModelVar]) -> Type[BaseModelVar]:
    """
    Class decorator that lets a pydantic model receive query parameters.

    For every entry in ``cls.model_fields`` an ``inspect.Parameter`` is built
    whose default is a ``Query(...)`` marker carrying the field's default
    (or ``...`` when the field is required) and its description. These
    parameters are installed as the ``__signature__`` of an async factory
    that instantiates the model, and the factory is attached to the class
    as ``cls.as_query``.

    :param cls: the pydantic model class to decorate
    :return: the same class, now exposing an ``as_query`` attribute
    """
    new_parameters = []
    for field_name, model_field in cls.model_fields.items():
        model_field: FieldInfo  # type: ignore
        # A field declared without an alias has ``alias`` set to None, which
        # ``inspect.Parameter`` rejects as a name; fall back to the field name.
        param_name = model_field.alias or field_name
        # ``...`` marks the parameter as required; otherwise reuse the
        # model's own default value.
        default = ... if model_field.is_required() else model_field.default
        new_parameters.append(
            inspect.Parameter(
                param_name,
                inspect.Parameter.POSITIONAL_ONLY,
                default=Query(default=default, description=model_field.description),
                annotation=model_field.annotation,
            )
        )

    async def as_query_func(**data):
        # Receives the collected values by keyword and builds the model.
        return cls(**data)

    # Replace the generic ``**data`` signature with the per-field parameters
    # so that anything inspecting ``__signature__`` sees one parameter per field.
    sig = inspect.signature(as_query_func)
    sig = sig.replace(parameters=new_parameters)
    as_query_func.__signature__ = sig  # type: ignore
    setattr(cls, 'as_query', as_query_func)
    return cls
def as_form(cls: Type[BaseModelVar]) -> Type[BaseModelVar]:
    """
    Class decorator that lets a pydantic model receive form parameters.

    For every entry in ``cls.model_fields`` an ``inspect.Parameter`` is built
    whose default is a ``Form(...)`` marker carrying the field's default
    (or ``...`` when the field is required) and its description. These
    parameters are installed as the ``__signature__`` of an async factory
    that instantiates the model, and the factory is attached to the class
    as ``cls.as_form``.

    :param cls: the pydantic model class to decorate
    :return: the same class, now exposing an ``as_form`` attribute
    """
    new_parameters = []
    for field_name, model_field in cls.model_fields.items():
        model_field: FieldInfo  # type: ignore
        # A field declared without an alias has ``alias`` set to None, which
        # ``inspect.Parameter`` rejects as a name; fall back to the field name.
        param_name = model_field.alias or field_name
        # ``...`` marks the parameter as required; otherwise reuse the
        # model's own default value.
        default = ... if model_field.is_required() else model_field.default
        new_parameters.append(
            inspect.Parameter(
                param_name,
                inspect.Parameter.POSITIONAL_ONLY,
                default=Form(default=default, description=model_field.description),
                annotation=model_field.annotation,
            )
        )

    async def as_form_func(**data):
        # Receives the collected values by keyword and builds the model.
        return cls(**data)

    # Replace the generic ``**data`` signature with the per-field parameters
    # so that anything inspecting ``__signature__`` sees one parameter per field.
    sig = inspect.signature(as_form_func)
    sig = sig.replace(parameters=new_parameters)
    as_form_func.__signature__ = sig  # type: ignore
    setattr(cls, 'as_form', as_form_func)
    return cls