From 588c46173e28abb50d7b176ff09b383e0b7b0a09 Mon Sep 17 00:00:00 2001
From: Melody Horn <melody@boringcactus.com>
Date: Mon, 29 Mar 2021 19:10:24 -0600
Subject: allow overriding package names in specific repos

---
 app.py            |  7 ++++---
 repos/__init__.py | 19 +++++++++++++++----
 repos/base.py     |  8 ++++++--
 3 files changed, 25 insertions(+), 9 deletions(-)

diff --git a/app.py b/app.py
index 248bb8d..05e8449 100644
--- a/app.py
+++ b/app.py
@@ -1,4 +1,4 @@
-from flask import Flask, make_response, render_template
+from flask import Flask, make_response, render_template, request
 from jinja2 import select_autoescape
 
 import repos
@@ -8,12 +8,13 @@ app.jinja_options['autoescape'] = select_autoescape(default=True)
 
 @app.route('/')
 def hello_world():
-    return '<a href="/badge/rust.svg">sample badge for Rust</a>'
+    return '<a href="/badge/nushell.svg?crates-io=nu">sample badge for Rust</a>'
 
 
 @app.route('/badge/<package>.svg')
 def badge(package: str):
-    versions = repos.get_versions(package)
+    args = request.args
+    versions = repos.get_versions(package, args)
     newest_version = max(versions.values())
     rendered = render_template('badge.svg.jinja', versions=versions, newest_version=newest_version)
     response = make_response(rendered)
diff --git a/repos/__init__.py b/repos/__init__.py
index 7023071..c836861 100644
--- a/repos/__init__.py
+++ b/repos/__init__.py
@@ -1,7 +1,7 @@
 from typing import Mapping, List
 
 from . import alpine_linux, arch_linux, crates_io
-from .base import Repository, Version
+from .base import Repository, slug, Version
 
 __all__ = [
     'get_versions',
@@ -19,10 +19,21 @@ all_repos: List[Repository] = [
     *repos_from(crates_io),
 ]
 
-def get_versions(package: str) -> Mapping[str, Version]:
+def get_versions(package: str, args: Mapping[str, str]) -> Mapping[str, Version]:
+    special_cases = dict()
+    for repo, name in args.items():
+        special_cases[repo] = name
     result = dict()
     for repo in all_repos:
         repo_versions = repo.get_versions()
-        if package in repo_versions:
-            result[repo.full_name()] = repo_versions[package]
+        if slug(repo.full_name()) in special_cases:
+            package_here = special_cases[slug(repo.full_name())]
+        elif slug(repo.family) in special_cases:
+            package_here = special_cases[slug(repo.family)]
+        elif slug(repo.repo) in special_cases:
+            package_here = special_cases[slug(repo.repo)]
+        else:
+            package_here = package
+        if package_here in repo_versions:
+            result[repo.full_name()] = repo_versions[package_here]
     return result
diff --git a/repos/base.py b/repos/base.py
index 3537f05..c20853a 100644
--- a/repos/base.py
+++ b/repos/base.py
@@ -11,14 +11,18 @@ import semver
 
 __all__ = [
     'Repository',
+    'slug',
     'Version',
 ]
 
 HTTP_DATE = '%a, %d %b %Y %H:%M:%S GMT'
 SLUGIFY = re.compile(r'\W+')
 
-def slug(text: str) -> str:
-    return SLUGIFY.sub('-', text.lower()).strip('-')
+def slug(text: Optional[str]) -> str:
+    if text is None:
+        return ''
+    else:
+        return SLUGIFY.sub('-', text.lower()).strip('-')
 
 @total_ordering
 @dataclass()
-- 
cgit v1.2.3