adding typing and flake8

This commit is contained in:
Peter Dwyer
2022-09-01 11:51:42 +01:00
parent 64d75115ca
commit e098de3122
3 changed files with 26 additions and 21 deletions

View File

@@ -19,3 +19,6 @@ repos:
rev: "5.0.4" rev: "5.0.4"
hooks: hooks:
- id: flake8 - id: flake8
additional_dependencies: [
'flake8-annotations==2.9.1',
]

View File

@@ -1,12 +1,12 @@
from pathlib import Path from pathlib import Path
from typing import Union from typing import Union, Optional, Dict, Iterable, List
from uuid import UUID from uuid import UUID
from django.conf import settings from django.conf import settings
from django.contrib.auth.models import User from django.contrib.auth.models import User
from django.contrib.auth.password_validation import validate_password from django.contrib.auth.password_validation import validate_password
from django.core.exceptions import ValidationError from django.core.exceptions import ValidationError
from django.db.models import Count, Case, When, F, PositiveSmallIntegerField from django.db.models import Count, Case, When, F, PositiveSmallIntegerField, FileField, QuerySet
from django.http import FileResponse from django.http import FileResponse
from drf_yasg.utils import swagger_auto_schema from drf_yasg.utils import swagger_auto_schema
from rest_framework import viewsets, serializers, mixins, permissions, status, renderers from rest_framework import viewsets, serializers, mixins, permissions, status, renderers
@@ -38,7 +38,7 @@ class AdminPasswordResetSerializer(serializers.Serializer):
class ClassificationSerializer(serializers.Serializer): class ClassificationSerializer(serializers.Serializer):
classification = serializers.IntegerField() classification = serializers.IntegerField()
def validate_classification(self, data): def validate_classification(self, data: int) -> int:
if data in models.Directory.Classification: if data in models.Directory.Classification:
return data return data
raise serializers.ValidationError('Invalid Classification sent.') raise serializers.ValidationError('Invalid Classification sent.')
@@ -86,7 +86,7 @@ class UserViewSet(viewsets.ModelViewSet): # pylint: disable=too-many-ancestors
class BrowseFileField(serializers.FileField): class BrowseFileField(serializers.FileField):
def to_representation(self, value): def to_representation(self, value: Optional[FileField]) -> Optional[str]:
if not value: if not value:
return None return None
return Path(settings.MEDIA_URL, value.name).as_posix() return Path(settings.MEDIA_URL, value.name).as_posix()
@@ -115,12 +115,12 @@ class BrowseViewSet(viewsets.GenericViewSet):
permission_classes = [permissions.IsAuthenticated] permission_classes = [permissions.IsAuthenticated]
lookup_field = 'selector' lookup_field = 'selector'
def list(self, request): def list(self, request: Request) -> Response:
serializer = self.get_serializer(generate_directory(request.user), many=True) serializer = self.get_serializer(generate_directory(request.user), many=True)
return Response(serializer.data) return Response(serializer.data)
@swagger_auto_schema(responses={status.HTTP_200_OK: BrowseSerializer(many=True)}) @swagger_auto_schema(responses={status.HTTP_200_OK: BrowseSerializer(many=True)})
def retrieve(self, request, selector: UUID): def retrieve(self, request: Request, selector: UUID) -> Response:
directory = models.Directory.objects.get(selector=selector) directory = models.Directory.objects.get(selector=selector)
serializer = self.get_serializer(generate_directory(request.user, directory), many=True) serializer = self.get_serializer(generate_directory(request.user, directory), many=True)
return Response(serializer.data) return Response(serializer.data)
@@ -157,7 +157,7 @@ class GenerateThumbnailViewSet(viewsets.ViewSet):
lookup_field = 'selector' lookup_field = 'selector'
@swagger_auto_schema(responses={status.HTTP_200_OK: GenerateThumbnailSerializer()}) @swagger_auto_schema(responses={status.HTTP_200_OK: GenerateThumbnailSerializer()})
def retrieve(self, _request, selector: UUID): def retrieve(self, _request: Request, selector: UUID) -> Response:
try: try:
directory = models.Directory.objects.get(selector=selector) directory = models.Directory.objects.get(selector=selector)
if not directory.thumbnail: if not directory.thumbnail:
@@ -290,7 +290,8 @@ class PassthroughRenderer(renderers.BaseRenderer): # pylint: disable=too-few-pu
media_type = '*/*' media_type = '*/*'
format = '' format = ''
def render(self, data, accepted_media_type=None, renderer_context=None): def render(self, data: bytes, accepted_media_type: Optional[str] = None, renderer_context: Optional[str] = None) \
-> bytes:
return data return data
@@ -300,12 +301,11 @@ class ImageViewSet(viewsets.ViewSet):
renderer_classes = [PassthroughRenderer] renderer_classes = [PassthroughRenderer]
@swagger_auto_schema(responses={status.HTTP_200_OK: "A Binary Image response"}) @swagger_auto_schema(responses={status.HTTP_200_OK: "A Binary Image response"})
def retrieve(self, _request, parent_lookup_selector, page): def retrieve(self, _request: Request, parent_lookup_selector: UUID, page: int) -> FileResponse:
book = models.ComicBook.objects.get(selector=parent_lookup_selector) book = models.ComicBook.objects.get(selector=parent_lookup_selector)
img, content = book.get_image(int(page)) img, content = book.get_image(int(page))
self.renderer_classes[0].media_type = content self.renderer_classes[0].media_type = content
response = FileResponse(img, content_type=content) return FileResponse(img, content_type=content)
return response
class StandardResultsSetPagination(PageNumberPagination): class StandardResultsSetPagination(PageNumberPagination):
@@ -331,7 +331,7 @@ class RecentComicsView(mixins.ListModelMixin, viewsets.GenericViewSet):
pagination_class = StandardResultsSetPagination pagination_class = StandardResultsSetPagination
permission_classes = [permissions.IsAuthenticated] permission_classes = [permissions.IsAuthenticated]
def get_queryset(self): def get_queryset(self) -> QuerySet[models.ComicBook]:
user = self.request.user user = self.request.user
if "search_text" in self.request.query_params: if "search_text" in self.request.query_params:
query = models.ComicBook.objects.filter(file_name__icontains=self.request.query_params["search_text"]) query = models.ComicBook.objects.filter(file_name__icontains=self.request.query_params["search_text"])
@@ -364,7 +364,7 @@ class ActionViewSet(viewsets.GenericViewSet):
serializer_class = ActionSerializer serializer_class = ActionSerializer
@action(detail=False, methods=['PUT']) @action(detail=False, methods=['PUT'])
def mark_read(self, request): def mark_read(self, request: Request) -> Response:
serializer = ActionSerializer(data=request.data) serializer = ActionSerializer(data=request.data)
if serializer.is_valid(): if serializer.is_valid():
comics = self.get_comics(serializer.data['selectors']) comics = self.get_comics(serializer.data['selectors'])
@@ -391,7 +391,7 @@ class ActionViewSet(viewsets.GenericViewSet):
status=status.HTTP_400_BAD_REQUEST) status=status.HTTP_400_BAD_REQUEST)
@action(detail=False, methods=['PUT']) @action(detail=False, methods=['PUT'])
def mark_unread(self, request): def mark_unread(self, request: Request) -> Response:
serializer = ActionSerializer(data=request.data) serializer = ActionSerializer(data=request.data)
if serializer.is_valid(): if serializer.is_valid():
comics = self.get_comics(serializer.data['selectors']) comics = self.get_comics(serializer.data['selectors'])
@@ -415,7 +415,7 @@ class ActionViewSet(viewsets.GenericViewSet):
return Response(serializer.errors, return Response(serializer.errors,
status=status.HTTP_400_BAD_REQUEST) status=status.HTTP_400_BAD_REQUEST)
def get_comics(self, selectors): def get_comics(self, selectors: Iterable[str, UUID]) -> List[str]:
data = set() data = set()
data = data.union( data = data.union(
set(models.ComicBook.objects.filter(selector__in=selectors).values_list('selector', flat=True))) set(models.ComicBook.objects.filter(selector__in=selectors).values_list('selector', flat=True)))
@@ -451,7 +451,7 @@ class PasswordResetSerializer(serializers.Serializer):
new_password = serializers.CharField(required=False) new_password = serializers.CharField(required=False)
new_password_confirm = serializers.CharField(required=False) new_password_confirm = serializers.CharField(required=False)
def validate_new_password(self, data): def validate_new_password(self, data: str) -> str:
if data == '': if data == '':
return data return data
try: try:
@@ -460,7 +460,7 @@ class PasswordResetSerializer(serializers.Serializer):
raise serializers.ValidationError(err) raise serializers.ValidationError(err)
return data return data
def validate(self, attrs): def validate(self, attrs: Dict[str, str]) -> Dict[str, str]:
super().validate(attrs) super().validate(attrs)
if attrs['new_password'] != attrs['new_password_confirm']: if attrs['new_password'] != attrs['new_password_confirm']:
raise serializers.ValidationError('New passwords do not match') raise serializers.ValidationError('New passwords do not match')
@@ -486,7 +486,7 @@ class AccountViewSet(viewsets.GenericViewSet):
return Response(AccountSerializer(request.user).data) return Response(AccountSerializer(request.user).data)
return Response({"errors": serializer.errors}, status.HTTP_400_BAD_REQUEST) return Response({"errors": serializer.errors}, status.HTTP_400_BAD_REQUEST)
def list(self, request): def list(self, request: Request) -> Response:
serializer = self.get_serializer(request.user) serializer = self.get_serializer(request.user)
return Response(serializer.data) return Response(serializer.data)
@@ -507,7 +507,7 @@ class AccountViewSet(viewsets.GenericViewSet):
@swagger_auto_schema(responses={status.HTTP_200_OK: RSSSerializer()}) @swagger_auto_schema(responses={status.HTTP_200_OK: RSSSerializer()})
@action(methods=['get'], detail=False, serializer_class=RSSSerializer) @action(methods=['get'], detail=False, serializer_class=RSSSerializer)
def feed_id(self, request: Request): def feed_id(self, request: Request) -> Response:
""" """
Return the RSS feed id needed to get users RSS Feed. Return the RSS feed id needed to get users RSS Feed.
""" """
@@ -532,7 +532,7 @@ class DirectoryViewSet(mixins.UpdateModelMixin, viewsets.GenericViewSet):
lookup_field = 'selector' lookup_field = 'selector'
@swagger_auto_schema(responses={200: DirectorySerializer(many=True)}) @swagger_auto_schema(responses={200: DirectorySerializer(many=True)})
def update(self, request: Request, selector: UUID) -> Response: # pylint: disable=arguments-differ def update(self, request: Request, selector: UUID) -> Response: # pylint: disable=arguments-differ
""" """
This will set the classification of a directory and all it's children. This will set the classification of a directory and all it's children.
""" """
@@ -557,7 +557,7 @@ class DirectoryViewSet(mixins.UpdateModelMixin, viewsets.GenericViewSet):
return Response(response.data) return Response(response.data)
return Response(serializer.errors, status.HTTP_400_BAD_REQUEST) return Response(serializer.errors, status.HTTP_400_BAD_REQUEST)
def partial_update(self, request, *args, **kwargs): def partial_update(self, request: Request, *args, **kwargs) -> Response:
""" """
This will set the classification of a directory and none of its children. This will set the classification of a directory and none of its children.
""" """

View File

@@ -5,6 +5,8 @@ addopts = --flake8
max-line-length = 120 max-line-length = 120
ignore = ignore =
* ANN101 * ANN101
* ANN002
* ANN003
# Ignore rules which contradicts black's formatting choices: # Ignore rules which contradicts black's formatting choices:
; * E501 ; * E501
; * W503 ; * W503