1package reddit23// TODO: Add tests.45import (6 "context"7 "encoding/base64"8 "encoding/json"9 "fmt"10 "html"11 "io"12 "log"13 "net/http"14 "strings"15 "time"1617 "website-feeds/model"18 "website-feeds/util"19)2021const (22 userAgent = `WebsiteFeeds/0.10.0 by u/psdelong`23 grantRequest = `grant_type=client_credentials`2425 accessTokenURL = `https://www.reddit.com/api/v1/access_token`26 topPostsURL = `https://oauth.reddit.com/r/%s/top?t=week&limit=%d`27 aboutSubredditURL = `https://oauth.reddit.com/api/info?sr_name=%s`28)2930type Client struct {31 client *http.Client32 token *string33 clientID string34 clientSecret string35 subreddit string36}3738type accessTokenResponse struct {39 AccessToken string `json:"access_token"`40}4142func New(43 clientID string,44 clientSecret string,45 subreddit string,46) *Client {47 return &Client{48 client: util.NewHttpClient(),49 token: nil,50 clientID: clientID,51 clientSecret: clientSecret,52 subreddit: subreddit,53 }54}5556func (c *Client) EnsureToken(ctx context.Context) error {57 if c.token == nil {58 token, err := getRedditAccessToken(59 ctx,60 c.client,61 c.clientID,62 c.clientSecret,63 )64 if err != nil {65 return fmt.Errorf("failed to get token: %w", err)66 }6768 c.token = &token69 }7071 return nil72}7374func (c *Client) URL() string {75 return fmt.Sprintf("https://old.reddit.com/r/%s", c.subreddit)76}7778func (c *Client) CommentURL(post model.Post) string {79 return fmt.Sprintf(80 "https://old.reddit.com/r/%s/%s",81 c.subreddit,82 strings.TrimPrefix(post.SiteUniqueID, "t3_"),83 )84}8586func makeRequest[T any](87 ctx context.Context,88 client *http.Client,89 userAgent, token, url string,90) (*T, error) {91 req, err := http.NewRequestWithContext(92 ctx,93 "GET",94 url,95 nil,96 )97 if err != nil {98 return nil, err99 }100101 req.Header.Add("User-Agent", userAgent)102 req.Header.Add("Authorization", fmt.Sprintf("bearer %s", token))103104 resp, err := client.Do(req)105 if err != nil {106 return nil, err107 }108 defer func() {109 if err := resp.Body.Close(); err != nil {110 log.Printf("failed to close body: %s", err)111 }112 }()113114 if resp.StatusCode != http.StatusOK {115 return nil, fmt.Errorf("received non-OK status: %q", resp.Status)116 }117118 var rResp T119 body, err := io.ReadAll(resp.Body)120 if err != nil {121 return nil, err122 }123 err = json.Unmarshal(body, &rResp)124 if err != nil {125 log.Print(string(body))126 return nil, err127 }128129 return &rResp, nil130}131132func (c *Client) FetchPosts(133 ctx context.Context,134 numPosts int,135) ([]model.Post, error) {136 if err := c.EnsureToken(ctx); err != nil {137 return nil, err138 }139140 type RedditResp struct {141 Data struct {142 Children []struct {143 Post struct {144 Id string `json:"name"`145 Title string `json:"title"`146 Upvotes int64 `json:"ups"`147 Created float64 `json:"created_utc"`148 URL string `json:"url"`149 Self bool `json:"is_self"`150 } `json:"data"`151 } `json:"children"`152 } `json:"data"`153 }154155 rResp, err := makeRequest[RedditResp](156 ctx,157 c.client,158 userAgent,159 *c.token,160 fmt.Sprintf(161 topPostsURL,162 c.subreddit,163 numPosts,164 ),165 )166 if err != nil {167 return nil, err168 }169170 posts := make([]model.Post, 0, len(rResp.Data.Children))171 for _, child := range rResp.Data.Children {172 internalPost := child.Post.Self173 var url *string174 if !internalPost {175 // Clone string because it will be overwritten on the next loop.176 str := child.Post.URL177 url = &str178 }179180 post := model.Post{181 SiteUniqueID: child.Post.Id,182 Title: html.UnescapeString(child.Post.Title),183 Score: int(child.Post.Upvotes),184 Created: time.Unix(int64(child.Post.Created), 0),185 URL: url,186 }187 posts = append(posts, post)188 }189190 err = validatePosts(posts)191 if err != nil {192 return nil, err193 }194195 return posts, nil196}197198func (c *Client) DisplayName(ctx context.Context) (string, error) {199 if err := c.EnsureToken(ctx); err != nil {200 return "", err201 }202203 type RedditResp struct {204 Data struct {205 Children []struct {206 Subreddit struct {207 DisplayNamePrefixed string `json:"display_name_prefixed"`208 } `json:"data"`209 } `json:"children"`210 } `json:"data"`211 }212213 rResp, err := makeRequest[RedditResp](214 ctx,215 c.client,216 userAgent,217 *c.token,218 fmt.Sprintf(219 aboutSubredditURL,220 c.subreddit,221 ),222 )223 if err != nil {224 return "", err225 }226227 if len(rResp.Data.Children) != 1 {228 return "", fmt.Errorf(229 "received strange number of subreddit results: expected 1, got %d",230 len(rResp.Data.Children),231 )232 }233234 return rResp.Data.Children[0].Subreddit.DisplayNamePrefixed, nil235}236237func validatePosts(posts []model.Post) error {238 for _, post := range posts {239 if post.SiteUniqueID == "" {240 return fmt.Errorf("post has no ID")241 }242 if post.Created.Equal(time.Unix(0, 0)) {243 return fmt.Errorf("created was 0 for %s", post.SiteUniqueID)244 }245 if post.Title == "" {246 return fmt.Errorf("title was empty for %s", post.SiteUniqueID)247 }248 }249250 return nil251}252253func getRedditAccessToken(254 ctx context.Context,255 client *http.Client,256 clientId, clientSecret string,257) (string, error) {258 req, err := http.NewRequestWithContext(259 ctx,260 "POST",261 accessTokenURL,262 strings.NewReader(grantRequest),263 )264 if err != nil {265 return "", err266 }267268 req.Header.Add("User-Agent", userAgent)269 req.Header.Add(270 "Authorization",271 fmt.Sprintf(272 "Basic %s",273 base64.StdEncoding.EncodeToString(274 fmt.Appendf(nil, "%s:%s", clientId, clientSecret),275 ),276 ),277 )278279 resp, err := client.Do(req)280 if err != nil {281 return "", err282 }283 defer func() {284 if err := resp.Body.Close(); err != nil {285 log.Printf("failed to close body: %s", err)286 }287 }()288289 body, err := io.ReadAll(resp.Body)290 if err != nil {291 return "", err292 }293294 var tokenResp accessTokenResponse295 err = json.NewDecoder(strings.NewReader(string(body))).Decode(&tokenResp)296 if err != nil {297 return "", err298 }299300 if tokenResp.AccessToken == "" {301 return "", fmt.Errorf("could not retrieve token from response")302 }303304 return tokenResp.AccessToken, nil305}