# path: repos/base.py
# blob: 1b4c42b47bbf6e520a9315efff4ad4b99cc89394
from dataclasses import dataclass, asdict as dataclass_asdict
from functools import total_ordering
import json
from pathlib import Path
import re
from typing import Any, Callable, Mapping, Optional

import requests
import semver

from . import db

# Public API of this module.
__all__ = [
    'Repository',
    'slug',
    'Version',
]

# strftime/strptime format of HTTP date headers (e.g. Last-Modified).
# NOTE(review): not referenced in this chunk — presumably used by callers; verify.
HTTP_DATE = '%a, %d %b %Y %H:%M:%S GMT'
# Runs of non-word characters; slug() collapses each run into a single hyphen.
SLUGIFY = re.compile(r'\W+')

def slug(text: Optional[str]) -> str:
    """Return a URL/path-safe slug: lower-cased, with runs of non-word
    characters collapsed to single hyphens and edge hyphens trimmed.

    None is treated as the empty string.
    """
    if text is None:
        return ''
    collapsed = re.sub(r'\W+', '-', text.lower())
    return collapsed.strip('-')

@total_ordering
@dataclass()
class Version:
    """A package version: the raw index string plus a cleaned form.

    Equality (from @dataclass) compares both fields; ordering (from
    @total_ordering via __lt__) prefers semantic-version comparison of
    the cleaned strings when possible.
    """

    original: str  # version string exactly as it appeared in the index
    clean: str     # normalized form, ideally a valid semver string

    def __str__(self) -> str:
        return self.original

    def __lt__(self, rhs: Any):
        if not isinstance(rhs, Version):
            return NotImplemented
        # Compare semantically when both cleaned strings parse as semver;
        # otherwise fall back to lexicographic order on the raw strings.
        is_semver = semver.VersionInfo.isvalid
        if is_semver(self.clean) and is_semver(rhs.clean):
            return semver.compare(self.clean, rhs.clean) < 0
        return self.original < rhs.original

class JSONEncoder(json.JSONEncoder):
    """json.JSONEncoder that additionally knows how to serialize Version."""

    def default(self, o: Any) -> Any:
        # Version serializes as its field dict; everything else defers to
        # the base class (which raises TypeError for unknown types).
        if not isinstance(o, Version):
            return super().default(o)
        return dataclass_asdict(o)

class JSONDecoder(json.JSONDecoder):
    """json.JSONDecoder that revives Version objects from their field dicts."""

    @staticmethod
    def object_hook(o: dict) -> Any:
        # A decoded object with exactly the keys {'original', 'clean'} is
        # taken to be a serialized Version (see JSONEncoder.default).
        return Version(**o) if o.keys() == {'original', 'clean'} else o

    def __init__(self):
        super().__init__(object_hook=self.object_hook)


@dataclass()
class Repository:
    """A package repository whose index can be downloaded, cached, and parsed.

    Cached artifacts (downloaded index, Last-Modified, ETag) live under
    ``data/<family>/<repo>`` (slugified); parsed versions are stored via
    the ``db`` module keyed by :meth:`full_name`.
    """

    family: Optional[str]  # grouping name; None for a standalone repository
    repo: str              # repository name within the family
    index_url: str         # URL of the package index to download
    # Parses the downloaded index file into {package name: Version}.
    parse: Callable[[Path], Mapping[str, Version]]

    def full_name(self) -> str:
        """Human-readable identifier; also the namespace used in the DB."""
        if self.family is None:
            return self.repo
        return f'{self.family} {self.repo}'

    def _cache_dir(self) -> Path:
        """Directory under data/ holding this repository's cache files."""
        if self.family is None:
            return Path('data') / slug(self.repo)
        return Path('data') / slug(self.family) / slug(self.repo)

    def _cache_file(self, name: str) -> Path:
        """Path of a named cache file inside :meth:`_cache_dir`."""
        return self._cache_dir() / name

    def update(self) -> None:
        """Re-download the index if it changed and refresh the DB.

        Sends a conditional GET (If-Modified-Since / If-None-Match) using
        the cached Last-Modified / ETag values; a 304 response is a no-op.

        Raises:
            requests.HTTPError: on any non-2xx, non-304 response.
        """
        self._cache_dir().mkdir(parents=True, exist_ok=True)
        headers = {}

        downloaded_file = self._cache_file('downloaded')

        mtime_file = self._cache_file('last-modified')
        if mtime_file.exists():
            headers['If-Modified-Since'] = mtime_file.read_text()

        etag_file = self._cache_file('etag')
        if etag_file.exists():
            headers['If-None-Match'] = etag_file.read_text()

        # Use the response as a context manager so the streamed connection
        # is closed on every path (304, HTTP error, success). The original
        # code leaked the connection whenever the body was not consumed.
        with requests.get(self.index_url, headers=headers, stream=True) as response:
            if response.status_code == requests.codes.not_modified:
                return
            response.raise_for_status()
            print('Re-downloading', self.full_name())
            with downloaded_file.open('wb') as f:
                for chunk in response.iter_content(chunk_size=256):
                    f.write(chunk)
            if 'Last-Modified' in response.headers:
                mtime_file.write_text(response.headers['Last-Modified'])
            if 'ETag' in response.headers:
                etag_file.write_text(response.headers['ETag'])

        # Only reached on a successful (re-)download.
        parsed_data = self.parse(downloaded_file)
        db.write(self.full_name(), parsed_data)

    def get_version(self, package_name: str) -> Optional[Version]:
        """Return the stored version of *package_name*, or None if unknown."""
        db_result = db.read(self.full_name(), package_name)
        if db_result is None:
            return None
        return Version(**db_result)