website-feeds

Make RSS feeds of your favorite "vote on posts" websites

git clone https://code.pdelong.com/website-feeds.git

  1package reddit
  2
  3// TODO: Add tests.
  4
  5import (
  6	"context"
  7	"encoding/base64"
  8	"encoding/json"
  9	"fmt"
 10	"html"
 11	"io"
 12	"log"
 13	"net/http"
 14	"strings"
 15	"time"
 16
 17	"website-feeds/model"
 18	"website-feeds/util"
 19)
 20
 21const (
 22	userAgent    = `WebsiteFeeds/0.10.0 by u/psdelong`
 23	grantRequest = `grant_type=client_credentials`
 24
 25	accessTokenURL    = `https://www.reddit.com/api/v1/access_token`
 26	topPostsURL       = `https://oauth.reddit.com/r/%s/top?t=week&limit=%d`
 27	aboutSubredditURL = `https://oauth.reddit.com/api/info?sr_name=%s`
 28)
 29
 30type Client struct {
 31	client       *http.Client
 32	token        *string
 33	clientID     string
 34	clientSecret string
 35	subreddit    string
 36}
 37
 38type accessTokenResponse struct {
 39	AccessToken string `json:"access_token"`
 40}
 41
 42func New(
 43	clientID string,
 44	clientSecret string,
 45	subreddit string,
 46) *Client {
 47	return &Client{
 48		client:       util.NewHttpClient(),
 49		token:        nil,
 50		clientID:     clientID,
 51		clientSecret: clientSecret,
 52		subreddit:    subreddit,
 53	}
 54}
 55
 56func (c *Client) EnsureToken(ctx context.Context) error {
 57	if c.token == nil {
 58		token, err := getRedditAccessToken(
 59			ctx,
 60			c.client,
 61			c.clientID,
 62			c.clientSecret,
 63		)
 64		if err != nil {
 65			return fmt.Errorf("failed to get token: %w", err)
 66		}
 67
 68		c.token = &token
 69	}
 70
 71	return nil
 72}
 73
 74func (c *Client) URL() string {
 75	return fmt.Sprintf("https://old.reddit.com/r/%s", c.subreddit)
 76}
 77
 78func (c *Client) CommentURL(post model.Post) string {
 79	return fmt.Sprintf(
 80		"https://old.reddit.com/r/%s/%s",
 81		c.subreddit,
 82		strings.TrimPrefix(post.SiteUniqueID, "t3_"),
 83	)
 84}
 85
 86func makeRequest[T any](
 87	ctx context.Context,
 88	client *http.Client,
 89	userAgent, token, url string,
 90) (*T, error) {
 91	req, err := http.NewRequestWithContext(
 92		ctx,
 93		"GET",
 94		url,
 95		nil,
 96	)
 97	if err != nil {
 98		return nil, err
 99	}
100
101	req.Header.Add("User-Agent", userAgent)
102	req.Header.Add("Authorization", fmt.Sprintf("bearer %s", token))
103
104	resp, err := client.Do(req)
105	if err != nil {
106		return nil, err
107	}
108	defer func() {
109		if err := resp.Body.Close(); err != nil {
110			log.Printf("failed to close body: %s", err)
111		}
112	}()
113
114	if resp.StatusCode != http.StatusOK {
115		return nil, fmt.Errorf("received non-OK status: %q", resp.Status)
116	}
117
118	var rResp T
119	body, err := io.ReadAll(resp.Body)
120	if err != nil {
121		return nil, err
122	}
123	err = json.Unmarshal(body, &rResp)
124	if err != nil {
125		log.Print(string(body))
126		return nil, err
127	}
128
129	return &rResp, nil
130}
131
132func (c *Client) FetchPosts(
133	ctx context.Context,
134	numPosts int,
135) ([]model.Post, error) {
136	if err := c.EnsureToken(ctx); err != nil {
137		return nil, err
138	}
139
140	type RedditResp struct {
141		Data struct {
142			Children []struct {
143				Post struct {
144					Id      string  `json:"name"`
145					Title   string  `json:"title"`
146					Upvotes int64   `json:"ups"`
147					Created float64 `json:"created_utc"`
148					URL     string  `json:"url"`
149					Self    bool    `json:"is_self"`
150				} `json:"data"`
151			} `json:"children"`
152		} `json:"data"`
153	}
154
155	rResp, err := makeRequest[RedditResp](
156		ctx,
157		c.client,
158		userAgent,
159		*c.token,
160		fmt.Sprintf(
161			topPostsURL,
162			c.subreddit,
163			numPosts,
164		),
165	)
166	if err != nil {
167		return nil, err
168	}
169
170	posts := make([]model.Post, 0, len(rResp.Data.Children))
171	for _, child := range rResp.Data.Children {
172		internalPost := child.Post.Self
173		var url *string
174		if !internalPost {
175			// Clone string because it will be overwritten on the next loop.
176			str := child.Post.URL
177			url = &str
178		}
179
180		post := model.Post{
181			SiteUniqueID: child.Post.Id,
182			Title:        html.UnescapeString(child.Post.Title),
183			Score:        int(child.Post.Upvotes),
184			Created:      time.Unix(int64(child.Post.Created), 0),
185			URL:          url,
186		}
187		posts = append(posts, post)
188	}
189
190	err = validatePosts(posts)
191	if err != nil {
192		return nil, err
193	}
194
195	return posts, nil
196}
197
198func (c *Client) DisplayName(ctx context.Context) (string, error) {
199	if err := c.EnsureToken(ctx); err != nil {
200		return "", err
201	}
202
203	type RedditResp struct {
204		Data struct {
205			Children []struct {
206				Subreddit struct {
207					DisplayNamePrefixed string `json:"display_name_prefixed"`
208				} `json:"data"`
209			} `json:"children"`
210		} `json:"data"`
211	}
212
213	rResp, err := makeRequest[RedditResp](
214		ctx,
215		c.client,
216		userAgent,
217		*c.token,
218		fmt.Sprintf(
219			aboutSubredditURL,
220			c.subreddit,
221		),
222	)
223	if err != nil {
224		return "", err
225	}
226
227	if len(rResp.Data.Children) != 1 {
228		return "", fmt.Errorf(
229			"received strange number of subreddit results: expected 1, got %d",
230			len(rResp.Data.Children),
231		)
232	}
233
234	return rResp.Data.Children[0].Subreddit.DisplayNamePrefixed, nil
235}
236
237func validatePosts(posts []model.Post) error {
238	for _, post := range posts {
239		if post.SiteUniqueID == "" {
240			return fmt.Errorf("post has no ID")
241		}
242		if post.Created.Equal(time.Unix(0, 0)) {
243			return fmt.Errorf("created was 0 for %s", post.SiteUniqueID)
244		}
245		if post.Title == "" {
246			return fmt.Errorf("title was empty for %s", post.SiteUniqueID)
247		}
248	}
249
250	return nil
251}
252
253func getRedditAccessToken(
254	ctx context.Context,
255	client *http.Client,
256	clientId, clientSecret string,
257) (string, error) {
258	req, err := http.NewRequestWithContext(
259		ctx,
260		"POST",
261		accessTokenURL,
262		strings.NewReader(grantRequest),
263	)
264	if err != nil {
265		return "", err
266	}
267
268	req.Header.Add("User-Agent", userAgent)
269	req.Header.Add(
270		"Authorization",
271		fmt.Sprintf(
272			"Basic %s",
273			base64.StdEncoding.EncodeToString(
274				fmt.Appendf(nil, "%s:%s", clientId, clientSecret),
275			),
276		),
277	)
278
279	resp, err := client.Do(req)
280	if err != nil {
281		return "", err
282	}
283	defer func() {
284		if err := resp.Body.Close(); err != nil {
285			log.Printf("failed to close body: %s", err)
286		}
287	}()
288
289	body, err := io.ReadAll(resp.Body)
290	if err != nil {
291		return "", err
292	}
293
294	var tokenResp accessTokenResponse
295	err = json.NewDecoder(strings.NewReader(string(body))).Decode(&tokenResp)
296	if err != nil {
297		return "", err
298	}
299
300	if tokenResp.AccessToken == "" {
301		return "", fmt.Errorf("could not retrieve token from response")
302	}
303
304	return tokenResp.AccessToken, nil
305}