blob: b14aa64c0462afda46050d1d2fe5941f2056c975
1 | /* |
2 | * Copyright (C) 2010-2011 Kevin Stone |
3 | * Copyright (C) 2016 Paul B Mahol |
4 | * |
5 | * This file is part of FFmpeg. |
6 | * |
7 | * FFmpeg is free software; you can redistribute it and/or modify |
8 | * it under the terms of the GNU General Public License as published by |
9 | * the Free Software Foundation; either version 2 of the License, or |
10 | * (at your option) any later version. |
11 | * |
12 | * FFmpeg is distributed in the hope that it will be useful, |
13 | * but WITHOUT ANY WARRANTY; without even the implied warranty of |
14 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the |
15 | * GNU General Public License for more details. |
16 | * |
17 | * You should have received a copy of the GNU General Public License along |
18 | * with FFmpeg; if not, write to the Free Software Foundation, Inc., |
19 | * 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. |
20 | */ |
21 | |
22 | #include <float.h> |
23 | |
24 | #include "libavutil/common.h" |
25 | #include "libavutil/float_dsp.h" |
26 | #include "libavutil/imgutils.h" |
27 | #include "libavutil/opt.h" |
28 | #include "libavutil/pixdesc.h" |
29 | #include "avfilter.h" |
30 | #include "formats.h" |
31 | #include "internal.h" |
32 | #include "video.h" |
33 | |
/**
 * Per-frame working state shared by the padding and evaluation stages.
 *
 * Every per-plane array is indexed by plane number; get_frame() fills the
 * geometry and pointers for each processed plane before the stages run.
 */
typedef struct FrameData {
    /* Padded copy of the source planes, written by the copy_pad stage. */
    uint8_t *paddedp[3];
    int      padded_stride[3];  /* bytes between rows of paddedp[] */
    int      padded_width[3];   /* plane width plus 64 columns of padding */
    int      padded_height[3];  /* plane height plus 12 rows of padding */

    /* Destination planes; these alias the data[] of the output frame. */
    uint8_t *dstp[3];
    int      dst_stride[3];

    /* Field selector (0 or 1) handed to the evaluation stages. */
    int      field[3];

    /* Per-row counters, incremented by the prescreening stage. */
    int32_t *lcount[3];
    /* Scratch buffers, allocated on first use in get_frame(). */
    float   *input;
    float   *temp;
} FrameData;
49 | |
50 | typedef struct NNEDIContext { |
51 | const AVClass *class; |
52 | |
53 | char *weights_file; |
54 | |
55 | AVFrame *src; |
56 | AVFrame *second; |
57 | AVFrame *dst; |
58 | int eof; |
59 | int64_t cur_pts; |
60 | |
61 | AVFloatDSPContext *fdsp; |
62 | int nb_planes; |
63 | int linesize[4]; |
64 | int planeheight[4]; |
65 | |
66 | float *weights0; |
67 | float *weights1[2]; |
68 | int asize; |
69 | int nns; |
70 | int xdia; |
71 | int ydia; |
72 | |
73 | // Parameters |
74 | int deint; |
75 | int field; |
76 | int process_plane; |
77 | int nsize; |
78 | int nnsparam; |
79 | int qual; |
80 | int etype; |
81 | int pscrn; |
82 | int fapprox; |
83 | |
84 | int max_value; |
85 | |
86 | void (*copy_pad)(const AVFrame *, FrameData *, struct NNEDIContext *, int); |
87 | void (*evalfunc_0)(struct NNEDIContext *, FrameData *); |
88 | void (*evalfunc_1)(struct NNEDIContext *, FrameData *); |
89 | |
90 | // Functions used in evalfunc_0 |
91 | void (*readpixels)(const uint8_t *, const int, float *); |
92 | void (*compute_network0)(struct NNEDIContext *s, const float *, const float *, uint8_t *); |
93 | int32_t (*process_line0)(const uint8_t *, int, uint8_t *, const uint8_t *, const int, const int, const int); |
94 | |
95 | // Functions used in evalfunc_1 |
96 | void (*extract)(const uint8_t *, const int, const int, const int, float *, float *); |
97 | void (*dot_prod)(struct NNEDIContext *, const float *, const float *, float *, const int, const int, const float *); |
98 | void (*expfunc)(float *, const int); |
99 | void (*wae5)(const float *, const int, float *); |
100 | |
101 | FrameData frame_data; |
102 | } NNEDIContext; |
103 | |
104 | #define OFFSET(x) offsetof(NNEDIContext, x) |
105 | #define FLAGS AV_OPT_FLAG_VIDEO_PARAM|AV_OPT_FLAG_FILTERING_PARAM |
106 | |
107 | static const AVOption nnedi_options[] = { |
108 | {"weights", "set weights file", OFFSET(weights_file), AV_OPT_TYPE_STRING, {.str="nnedi3_weights.bin"}, 0, 0, FLAGS }, |
109 | {"deint", "set which frames to deinterlace", OFFSET(deint), AV_OPT_TYPE_INT, {.i64=0}, 0, 1, FLAGS, "deint" }, |
110 | {"all", "deinterlace all frames", 0, AV_OPT_TYPE_CONST, {.i64=0}, 0, 0, FLAGS, "deint" }, |
111 | {"interlaced", "only deinterlace frames marked as interlaced", 0, AV_OPT_TYPE_CONST, {.i64=1}, 0, 0, FLAGS, "deint" }, |
112 | {"field", "set mode of operation", OFFSET(field), AV_OPT_TYPE_INT, {.i64=-1}, -2, 3, FLAGS, "field" }, |
113 | {"af", "use frame flags, both fields", 0, AV_OPT_TYPE_CONST, {.i64=-2}, 0, 0, FLAGS, "field" }, |
114 | {"a", "use frame flags, single field", 0, AV_OPT_TYPE_CONST, {.i64=-1}, 0, 0, FLAGS, "field" }, |
115 | {"t", "use top field only", 0, AV_OPT_TYPE_CONST, {.i64=0}, 0, 0, FLAGS, "field" }, |
116 | {"b", "use bottom field only", 0, AV_OPT_TYPE_CONST, {.i64=1}, 0, 0, FLAGS, "field" }, |
117 | {"tf", "use both fields, top first", 0, AV_OPT_TYPE_CONST, {.i64=2}, 0, 0, FLAGS, "field" }, |
118 | {"bf", "use both fields, bottom first", 0, AV_OPT_TYPE_CONST, {.i64=3}, 0, 0, FLAGS, "field" }, |
119 | {"planes", "set which planes to process", OFFSET(process_plane), AV_OPT_TYPE_INT, {.i64=7}, 0, 7, FLAGS }, |
120 | {"nsize", "set size of local neighborhood around each pixel, used by the predictor neural network", OFFSET(nsize), AV_OPT_TYPE_INT, {.i64=6}, 0, 6, FLAGS, "nsize" }, |
121 | {"s8x6", NULL, 0, AV_OPT_TYPE_CONST, {.i64=0}, 0, 0, FLAGS, "nsize" }, |
122 | {"s16x6", NULL, 0, AV_OPT_TYPE_CONST, {.i64=1}, 0, 0, FLAGS, "nsize" }, |
123 | {"s32x6", NULL, 0, AV_OPT_TYPE_CONST, {.i64=2}, 0, 0, FLAGS, "nsize" }, |
124 | {"s48x6", NULL, 0, AV_OPT_TYPE_CONST, {.i64=3}, 0, 0, FLAGS, "nsize" }, |
125 | {"s8x4", NULL, 0, AV_OPT_TYPE_CONST, {.i64=4}, 0, 0, FLAGS, "nsize" }, |
126 | {"s16x4", NULL, 0, AV_OPT_TYPE_CONST, {.i64=5}, 0, 0, FLAGS, "nsize" }, |
127 | {"s32x4", NULL, 0, AV_OPT_TYPE_CONST, {.i64=6}, 0, 0, FLAGS, "nsize" }, |
128 | {"nns", "set number of neurons in predictor neural network", OFFSET(nnsparam), AV_OPT_TYPE_INT, {.i64=1}, 0, 4, FLAGS, "nns" }, |
129 | {"n16", NULL, 0, AV_OPT_TYPE_CONST, {.i64=0}, 0, 0, FLAGS, "nns" }, |
130 | {"n32", NULL, 0, AV_OPT_TYPE_CONST, {.i64=1}, 0, 0, FLAGS, "nns" }, |
131 | {"n64", NULL, 0, AV_OPT_TYPE_CONST, {.i64=2}, 0, 0, FLAGS, "nns" }, |
132 | {"n128", NULL, 0, AV_OPT_TYPE_CONST, {.i64=3}, 0, 0, FLAGS, "nns" }, |
133 | {"n256", NULL, 0, AV_OPT_TYPE_CONST, {.i64=4}, 0, 0, FLAGS, "nns" }, |
134 | {"qual", "set quality", OFFSET(qual), AV_OPT_TYPE_INT, {.i64=1}, 1, 2, FLAGS, "qual" }, |
135 | {"fast", NULL, 0, AV_OPT_TYPE_CONST, {.i64=1}, 0, 0, FLAGS, "qual" }, |
136 | {"slow", NULL, 0, AV_OPT_TYPE_CONST, {.i64=2}, 0, 0, FLAGS, "qual" }, |
137 | {"etype", "set which set of weights to use in the predictor", OFFSET(etype), AV_OPT_TYPE_INT, {.i64=0}, 0, 1, FLAGS, "etype" }, |
138 | {"a", "weights trained to minimize absolute error", 0, AV_OPT_TYPE_CONST, {.i64=0}, 0, 0, FLAGS, "etype" }, |
139 | {"s", "weights trained to minimize squared error", 0, AV_OPT_TYPE_CONST, {.i64=1}, 0, 0, FLAGS, "etype" }, |
140 | {"pscrn", "set prescreening", OFFSET(pscrn), AV_OPT_TYPE_INT, {.i64=2}, 0, 2, FLAGS, "pscrn" }, |
141 | {"none", NULL, 0, AV_OPT_TYPE_CONST, {.i64=0}, 0, 0, FLAGS, "pscrn" }, |
142 | {"original", NULL, 0, AV_OPT_TYPE_CONST, {.i64=1}, 0, 0, FLAGS, "pscrn" }, |
143 | {"new", NULL, 0, AV_OPT_TYPE_CONST, {.i64=2}, 0, 0, FLAGS, "pscrn" }, |
144 | {"fapprox", NULL, OFFSET(fapprox), AV_OPT_TYPE_INT, {.i64=0}, 0, 3, FLAGS }, |
145 | { NULL } |
146 | }; |
147 | |
148 | AVFILTER_DEFINE_CLASS(nnedi); |
149 | |
150 | static int config_input(AVFilterLink *inlink) |
151 | { |
152 | AVFilterContext *ctx = inlink->dst; |
153 | NNEDIContext *s = ctx->priv; |
154 | const AVPixFmtDescriptor *desc = av_pix_fmt_desc_get(inlink->format); |
155 | int ret; |
156 | |
157 | s->nb_planes = av_pix_fmt_count_planes(inlink->format); |
158 | if ((ret = av_image_fill_linesizes(s->linesize, inlink->format, inlink->w)) < 0) |
159 | return ret; |
160 | |
161 | s->planeheight[1] = s->planeheight[2] = AV_CEIL_RSHIFT(inlink->h, desc->log2_chroma_h); |
162 | s->planeheight[0] = s->planeheight[3] = inlink->h; |
163 | |
164 | return 0; |
165 | } |
166 | |
167 | static int config_output(AVFilterLink *outlink) |
168 | { |
169 | AVFilterContext *ctx = outlink->src; |
170 | NNEDIContext *s = ctx->priv; |
171 | |
172 | outlink->time_base.num = ctx->inputs[0]->time_base.num; |
173 | outlink->time_base.den = ctx->inputs[0]->time_base.den * 2; |
174 | outlink->w = ctx->inputs[0]->w; |
175 | outlink->h = ctx->inputs[0]->h; |
176 | |
177 | if (s->field > 1 || s->field == -2) |
178 | outlink->frame_rate = av_mul_q(ctx->inputs[0]->frame_rate, |
179 | (AVRational){2, 1}); |
180 | |
181 | return 0; |
182 | } |
183 | |
184 | static int query_formats(AVFilterContext *ctx) |
185 | { |
186 | static const enum AVPixelFormat pix_fmts[] = { |
187 | AV_PIX_FMT_YUV410P, AV_PIX_FMT_YUV411P, |
188 | AV_PIX_FMT_YUV420P, AV_PIX_FMT_YUV422P, |
189 | AV_PIX_FMT_YUV440P, AV_PIX_FMT_YUV444P, |
190 | AV_PIX_FMT_YUVJ444P, AV_PIX_FMT_YUVJ440P, |
191 | AV_PIX_FMT_YUVJ422P, AV_PIX_FMT_YUVJ420P, |
192 | AV_PIX_FMT_YUVJ411P, |
193 | AV_PIX_FMT_GBRP, |
194 | AV_PIX_FMT_GRAY8, |
195 | AV_PIX_FMT_NONE |
196 | }; |
197 | |
198 | AVFilterFormats *fmts_list = ff_make_format_list(pix_fmts); |
199 | if (!fmts_list) |
200 | return AVERROR(ENOMEM); |
201 | return ff_set_common_formats(ctx, fmts_list); |
202 | } |
203 | |
204 | static void copy_pad(const AVFrame *src, FrameData *frame_data, NNEDIContext *s, int fn) |
205 | { |
206 | const int off = 1 - fn; |
207 | int plane, y, x; |
208 | |
209 | for (plane = 0; plane < s->nb_planes; plane++) { |
210 | const uint8_t *srcp = (const uint8_t *)src->data[plane]; |
211 | uint8_t *dstp = (uint8_t *)frame_data->paddedp[plane]; |
212 | |
213 | const int src_stride = src->linesize[plane]; |
214 | const int dst_stride = frame_data->padded_stride[plane]; |
215 | |
216 | const int src_height = s->planeheight[plane]; |
217 | const int dst_height = frame_data->padded_height[plane]; |
218 | |
219 | const int src_width = s->linesize[plane]; |
220 | const int dst_width = frame_data->padded_width[plane]; |
221 | |
222 | int c = 4; |
223 | |
224 | if (!(s->process_plane & (1 << plane))) |
225 | continue; |
226 | |
227 | // Copy. |
228 | for (y = off; y < src_height; y += 2) |
229 | memcpy(dstp + 32 + (6 + y) * dst_stride, |
230 | srcp + y * src_stride, |
231 | src_width * sizeof(uint8_t)); |
232 | |
233 | // And pad. |
234 | dstp += (6 + off) * dst_stride; |
235 | for (y = 6 + off; y < dst_height - 6; y += 2) { |
236 | int c = 2; |
237 | |
238 | for (x = 0; x < 32; x++) |
239 | dstp[x] = dstp[64 - x]; |
240 | |
241 | for (x = dst_width - 32; x < dst_width; x++, c += 2) |
242 | dstp[x] = dstp[x - c]; |
243 | |
244 | dstp += dst_stride * 2; |
245 | } |
246 | |
247 | dstp = (uint8_t *)frame_data->paddedp[plane]; |
248 | for (y = off; y < 6; y += 2) |
249 | memcpy(dstp + y * dst_stride, |
250 | dstp + (12 + 2 * off - y) * dst_stride, |
251 | dst_width * sizeof(uint8_t)); |
252 | |
253 | for (y = dst_height - 6 + off; y < dst_height; y += 2, c += 4) |
254 | memcpy(dstp + y * dst_stride, |
255 | dstp + (y - c) * dst_stride, |
256 | dst_width * sizeof(uint8_t)); |
257 | } |
258 | } |
259 | |
/**
 * Apply the Elliott activation x / (1 + |x|) in place to n values.
 */
static void elliott(float *data, const int n)
{
    int k = 0;

    while (k < n) {
        const float v = data[k];
        const float magnitude = v >= 0 ? v : -v;

        data[k] = v / (1.0f + magnitude);
        k++;
    }
}
267 | |
268 | static void dot_prod(NNEDIContext *s, const float *data, const float *weights, float *vals, const int n, const int len, const float *scale) |
269 | { |
270 | int i; |
271 | |
272 | for (i = 0; i < n; i++) { |
273 | float sum; |
274 | |
275 | sum = s->fdsp->scalarproduct_float(data, &weights[i * len], len); |
276 | |
277 | vals[i] = sum * scale[0] + weights[n * len + i]; |
278 | } |
279 | } |
280 | |
281 | static void dot_prods(NNEDIContext *s, const float *dataf, const float *weightsf, float *vals, const int n, const int len, const float *scale) |
282 | { |
283 | const int16_t *data = (int16_t *)dataf; |
284 | const int16_t *weights = (int16_t *)weightsf; |
285 | const float *wf = (float *)&weights[n * len]; |
286 | int i, j; |
287 | |
288 | for (i = 0; i < n; i++) { |
289 | int sum = 0, off = ((i >> 2) << 3) + (i & 3); |
290 | for (j = 0; j < len; j++) |
291 | sum += data[j] * weights[i * len + j]; |
292 | |
293 | vals[i] = sum * wf[off] * scale[0] + wf[off + 4]; |
294 | } |
295 | } |
296 | |
297 | static void compute_network0(NNEDIContext *s, const float *input, const float *weights, uint8_t *d) |
298 | { |
299 | float t, temp[12], scale = 1.0f; |
300 | |
301 | dot_prod(s, input, weights, temp, 4, 48, &scale); |
302 | t = temp[0]; |
303 | elliott(temp, 4); |
304 | temp[0] = t; |
305 | dot_prod(s, temp, weights + 4 * 49, temp + 4, 4, 4, &scale); |
306 | elliott(temp + 4, 4); |
307 | dot_prod(s, temp, weights + 4 * 49 + 4 * 5, temp + 8, 4, 8, &scale); |
308 | if (FFMAX(temp[10], temp[11]) <= FFMAX(temp[8], temp[9])) |
309 | d[0] = 1; |
310 | else |
311 | d[0] = 0; |
312 | } |
313 | |
314 | static void compute_network0_i16(NNEDIContext *s, const float *inputf, const float *weightsf, uint8_t *d) |
315 | { |
316 | const float *wf = weightsf + 2 * 48; |
317 | float t, temp[12], scale = 1.0f; |
318 | |
319 | dot_prods(s, inputf, weightsf, temp, 4, 48, &scale); |
320 | t = temp[0]; |
321 | elliott(temp, 4); |
322 | temp[0] = t; |
323 | dot_prod(s, temp, wf + 8, temp + 4, 4, 4, &scale); |
324 | elliott(temp + 4, 4); |
325 | dot_prod(s, temp, wf + 8 + 4 * 5, temp + 8, 4, 8, &scale); |
326 | if (FFMAX(temp[10], temp[11]) <= FFMAX(temp[8], temp[9])) |
327 | d[0] = 1; |
328 | else |
329 | d[0] = 0; |
330 | } |
331 | |
/**
 * Gather a 12x4 window into 48 floats, reading every second row
 * (rows are 2 * pitch bytes apart).
 */
static void pixel2float48(const uint8_t *t8, const int pitch, float *p)
{
    int row, col;

    for (row = 0; row < 4; row++) {
        const uint8_t *line = t8 + row * pitch * 2;
        float *out = p + row * 12;

        for (col = 0; col < 12; col++)
            out[col] = line[col];
    }
}
341 | |
/**
 * Gather a 12x4 window into 48 int16 values stored in the buffer behind
 * pf, reading every second row (rows are 2 * pitch bytes apart).
 */
static void byte2word48(const uint8_t *t, const int pitch, float *pf)
{
    int16_t *words = (int16_t *)pf;
    int row, col;

    for (row = 0; row < 4; row++) {
        const uint8_t *line = t + row * pitch * 2;

        for (col = 0; col < 12; col++)
            words[row * 12 + col] = line[col];
    }
}
351 | |
/**
 * Apply the prescreener decisions to one output row.
 *
 * Where tempu[x] is non-zero the pixel is interpolated with a 4-tap cubic
 * filter (-3, 19, 19, -3) / 32 over rows 0, 2, 4 and 6 of src3p8 and
 * clamped to [0, max_value - 1]. Elsewhere the pixel is set to 255, the
 * marker for pixels still to be predicted, and counted.
 *
 * @param chroma unused
 * @return number of pixels marked with 255
 */
static int32_t process_line0(const uint8_t *tempu, int width, uint8_t *dstp8, const uint8_t *src3p8, const int src_pitch, const int max_value, const int chroma)
{
    const int lowest  = 0;
    const int highest = max_value - 1; // Technically the -1 is only needed for 8 and 16 bit input.
    int pending = 0;
    int col;

    for (col = 0; col < width; col++) {
        const uint8_t *px = src3p8 + col;
        int cubic;

        if (!tempu[col]) {
            dstp8[col] = 255;
            pending++;
            continue;
        }

        cubic  = 19 * (px[src_pitch * 2] + px[src_pitch * 4]) - 3 * (px[0] + px[src_pitch * 6]);
        cubic /= 32;
        if (cubic > highest)
            cubic = highest;
        if (cubic < lowest)
            cubic = lowest;
        dstp8[col] = cubic;
    }

    return pending;
}
371 | |
372 | // new prescreener functions |
/**
 * Gather a 16x4 window into 64 int16 values stored in the buffer behind
 * p, reading every second row (rows are 2 * pitch bytes apart).
 */
static void byte2word64(const uint8_t *t, const int pitch, float *p)
{
    int16_t *words = (int16_t *)p;
    int row, col;

    for (row = 0; row < 4; row++) {
        const uint8_t *line = t + row * pitch * 2;

        for (col = 0; col < 16; col++)
            words[row * 16 + col] = line[col];
    }
}
382 | |
383 | static void compute_network0new(NNEDIContext *s, const float *datai, const float *weights, uint8_t *d) |
384 | { |
385 | int16_t *data = (int16_t *)datai; |
386 | int16_t *ws = (int16_t *)weights; |
387 | float *wf = (float *)&ws[4 * 64]; |
388 | float vals[8]; |
389 | int mask, i, j; |
390 | |
391 | for (i = 0; i < 4; i++) { |
392 | int sum = 0; |
393 | float t; |
394 | |
395 | for (j = 0; j < 64; j++) |
396 | sum += data[j] * ws[(i << 3) + ((j >> 3) << 5) + (j & 7)]; |
397 | t = sum * wf[i] + wf[4 + i]; |
398 | vals[i] = t / (1.0f + FFABS(t)); |
399 | } |
400 | |
401 | for (i = 0; i < 4; i++) { |
402 | float sum = 0.0f; |
403 | |
404 | for (j = 0; j < 4; j++) |
405 | sum += vals[j] * wf[8 + i + (j << 2)]; |
406 | vals[4 + i] = sum + wf[8 + 16 + i]; |
407 | } |
408 | |
409 | mask = 0; |
410 | for (i = 0; i < 4; i++) { |
411 | if (vals[4 + i] > 0.0f) |
412 | mask |= (0x1 << (i << 3)); |
413 | } |
414 | |
415 | ((int *)d)[0] = mask; |
416 | } |
417 | |
418 | static void evalfunc_0(NNEDIContext *s, FrameData *frame_data) |
419 | { |
420 | float *input = frame_data->input; |
421 | const float *weights0 = s->weights0; |
422 | float *temp = frame_data->temp; |
423 | uint8_t *tempu = (uint8_t *)temp; |
424 | int plane, x, y; |
425 | |
426 | // And now the actual work. |
427 | for (plane = 0; plane < s->nb_planes; plane++) { |
428 | const uint8_t *srcp = (const uint8_t *)frame_data->paddedp[plane]; |
429 | const int src_stride = frame_data->padded_stride[plane] / sizeof(uint8_t); |
430 | |
431 | const int width = frame_data->padded_width[plane]; |
432 | const int height = frame_data->padded_height[plane]; |
433 | |
434 | uint8_t *dstp = (uint8_t *)frame_data->dstp[plane]; |
435 | const int dst_stride = frame_data->dst_stride[plane] / sizeof(uint8_t); |
436 | const uint8_t *src3p; |
437 | int ystart, ystop; |
438 | int32_t *lcount; |
439 | |
440 | if (!(s->process_plane & (1 << plane))) |
441 | continue; |
442 | |
443 | for (y = 1 - frame_data->field[plane]; y < height - 12; y += 2) { |
444 | memcpy(dstp + y * dst_stride, |
445 | srcp + 32 + (6 + y) * src_stride, |
446 | (width - 64) * sizeof(uint8_t)); |
447 | |
448 | } |
449 | |
450 | ystart = 6 + frame_data->field[plane]; |
451 | ystop = height - 6; |
452 | srcp += ystart * src_stride; |
453 | dstp += (ystart - 6) * dst_stride - 32; |
454 | src3p = srcp - src_stride * 3; |
455 | lcount = frame_data->lcount[plane] - 6; |
456 | |
457 | if (s->pscrn == 1) { // original |
458 | for (y = ystart; y < ystop; y += 2) { |
459 | for (x = 32; x < width - 32; x++) { |
460 | s->readpixels((const uint8_t *)(src3p + x - 5), src_stride, input); |
461 | s->compute_network0(s, input, weights0, tempu+x); |
462 | } |
463 | lcount[y] += s->process_line0(tempu + 32, width - 64, (uint8_t *)(dstp + 32), (const uint8_t *)(src3p + 32), src_stride, s->max_value, plane); |
464 | src3p += src_stride * 2; |
465 | dstp += dst_stride * 2; |
466 | } |
467 | } else if (s->pscrn > 1) { // new |
468 | for (y = ystart; y < ystop; y += 2) { |
469 | for (x = 32; x < width - 32; x += 4) { |
470 | s->readpixels((const uint8_t *)(src3p + x - 6), src_stride, input); |
471 | s->compute_network0(s, input, weights0, tempu + x); |
472 | } |
473 | lcount[y] += s->process_line0(tempu + 32, width - 64, (uint8_t *)(dstp + 32), (const uint8_t *)(src3p + 32), src_stride, s->max_value, plane); |
474 | src3p += src_stride * 2; |
475 | dstp += dst_stride * 2; |
476 | } |
477 | } else { // no prescreening |
478 | for (y = ystart; y < ystop; y += 2) { |
479 | memset(dstp + 32, 255, (width - 64) * sizeof(uint8_t)); |
480 | lcount[y] += width - 64; |
481 | dstp += dst_stride * 2; |
482 | } |
483 | } |
484 | } |
485 | } |
486 | |
487 | static void extract_m8(const uint8_t *srcp8, const int stride, const int xdia, const int ydia, float *mstd, float *input) |
488 | { |
489 | // uint8_t or uint16_t or float |
490 | const uint8_t *srcp = (const uint8_t *)srcp8; |
491 | float scale; |
492 | double tmp; |
493 | |
494 | // int32_t or int64_t or double |
495 | int64_t sum = 0, sumsq = 0; |
496 | int y, x; |
497 | |
498 | for (y = 0; y < ydia; y++) { |
499 | const uint8_t *srcpT = srcp + y * stride * 2; |
500 | |
501 | for (x = 0; x < xdia; x++) { |
502 | sum += srcpT[x]; |
503 | sumsq += (uint32_t)srcpT[x] * (uint32_t)srcpT[x]; |
504 | input[x] = srcpT[x]; |
505 | } |
506 | input += xdia; |
507 | } |
508 | scale = 1.0f / (xdia * ydia); |
509 | mstd[0] = sum * scale; |
510 | tmp = (double)sumsq * scale - (double)mstd[0] * mstd[0]; |
511 | mstd[3] = 0.0f; |
512 | if (tmp <= FLT_EPSILON) |
513 | mstd[1] = mstd[2] = 0.0f; |
514 | else { |
515 | mstd[1] = sqrt(tmp); |
516 | mstd[2] = 1.0f / mstd[1]; |
517 | } |
518 | } |
519 | |
520 | static void extract_m8_i16(const uint8_t *srcp, const int stride, const int xdia, const int ydia, float *mstd, float *inputf) |
521 | { |
522 | int16_t *input = (int16_t *)inputf; |
523 | float scale; |
524 | int sum = 0, sumsq = 0; |
525 | int y, x; |
526 | |
527 | for (y = 0; y < ydia; y++) { |
528 | const uint8_t *srcpT = srcp + y * stride * 2; |
529 | for (x = 0; x < xdia; x++) { |
530 | sum += srcpT[x]; |
531 | sumsq += srcpT[x] * srcpT[x]; |
532 | input[x] = srcpT[x]; |
533 | } |
534 | input += xdia; |
535 | } |
536 | scale = 1.0f / (float)(xdia * ydia); |
537 | mstd[0] = sum * scale; |
538 | mstd[1] = sumsq * scale - mstd[0] * mstd[0]; |
539 | mstd[3] = 0.0f; |
540 | if (mstd[1] <= FLT_EPSILON) |
541 | mstd[1] = mstd[2] = 0.0f; |
542 | else { |
543 | mstd[1] = sqrt(mstd[1]); |
544 | mstd[2] = 1.0f / mstd[1]; |
545 | } |
546 | } |
547 | |
548 | |
/* Bounds applied to the exponent before calling exp(). */
static const float exp_lo = -80.0f;
static const float exp_hi = +80.0f;

/**
 * Replace each of the n values with exp(value), after clipping the value
 * to [exp_lo, exp_hi].
 */
static void e2_m16(float *s, const int n)
{
    int k = 0;

    while (k < n) {
        const float clipped = av_clipf(s[k], exp_lo, exp_hi);

        s[k] = exp(clipped);
        k++;
    }
}
559 | |
/* Smallest weight sum for which the weighted average is still formed.
 * File-local: given internal linkage so it does not leak into the global
 * symbol namespace. */
static const float min_weight_sum = 1e-10f;

/**
 * Combine the predictor outputs into a pixel estimate.
 *
 * @param w    2 * n values: n weights followed by n raw outputs; the
 *             outputs are passed through x / (1 + |x|) before averaging
 * @param n    number of weight/output pairs
 * @param mstd mean in [0], standard deviation in [1]; the estimate
 *             5 * average * mstd[1] + mstd[0] is added to mstd[3], or just
 *             mstd[0] when the weight sum does not exceed min_weight_sum
 */
static void weighted_avg_elliott_mul5_m16(const float *w, const int n, float *mstd)
{
    float vsum = 0.0f, wsum = 0.0f;
    int i;

    for (i = 0; i < n; i++) {
        const float out = w[n + i];
        const float magnitude = out >= 0 ? out : -out;

        vsum += w[i] * (out / (1.0f + magnitude));
        wsum += w[i];
    }

    if (wsum > min_weight_sum)
        mstd[3] += ((5.0f * vsum) / wsum) * mstd[1] + mstd[0];
    else
        mstd[3] += mstd[0];
}
576 | |
577 | |
578 | static void evalfunc_1(NNEDIContext *s, FrameData *frame_data) |
579 | { |
580 | float *input = frame_data->input; |
581 | float *temp = frame_data->temp; |
582 | float **weights1 = s->weights1; |
583 | const int qual = s->qual; |
584 | const int asize = s->asize; |
585 | const int nns = s->nns; |
586 | const int xdia = s->xdia; |
587 | const int xdiad2m1 = (xdia / 2) - 1; |
588 | const int ydia = s->ydia; |
589 | const float scale = 1.0f / (float)qual; |
590 | int plane, y, x, i; |
591 | |
592 | for (plane = 0; plane < s->nb_planes; plane++) { |
593 | const uint8_t *srcp = (const uint8_t *)frame_data->paddedp[plane]; |
594 | const int src_stride = frame_data->padded_stride[plane] / sizeof(uint8_t); |
595 | |
596 | const int width = frame_data->padded_width[plane]; |
597 | const int height = frame_data->padded_height[plane]; |
598 | |
599 | uint8_t *dstp = (uint8_t *)frame_data->dstp[plane]; |
600 | const int dst_stride = frame_data->dst_stride[plane] / sizeof(uint8_t); |
601 | |
602 | const int ystart = frame_data->field[plane]; |
603 | const int ystop = height - 12; |
604 | const uint8_t *srcpp; |
605 | |
606 | if (!(s->process_plane & (1 << plane))) |
607 | continue; |
608 | |
609 | srcp += (ystart + 6) * src_stride; |
610 | dstp += ystart * dst_stride - 32; |
611 | srcpp = srcp - (ydia - 1) * src_stride - xdiad2m1; |
612 | |
613 | for (y = ystart; y < ystop; y += 2) { |
614 | for (x = 32; x < width - 32; x++) { |
615 | float mstd[4]; |
616 | |
617 | if (dstp[x] != 255) |
618 | continue; |
619 | |
620 | s->extract((const uint8_t *)(srcpp + x), src_stride, xdia, ydia, mstd, input); |
621 | for (i = 0; i < qual; i++) { |
622 | s->dot_prod(s, input, weights1[i], temp, nns * 2, asize, mstd + 2); |
623 | s->expfunc(temp, nns); |
624 | s->wae5(temp, nns, mstd); |
625 | } |
626 | |
627 | dstp[x] = FFMIN(FFMAX((int)(mstd[3] * scale + 0.5f), 0), s->max_value); |
628 | } |
629 | srcpp += src_stride * 2; |
630 | dstp += dst_stride * 2; |
631 | } |
632 | } |
633 | } |
634 | |
635 | #define NUM_NSIZE 7 |
636 | #define NUM_NNS 5 |
637 | |
638 | static int roundds(const double f) |
639 | { |
640 | if (f - floor(f) >= 0.5) |
641 | return FFMIN((int)ceil(f), 32767); |
642 | return FFMAX((int)floor(f), -32768); |
643 | } |
644 | |
645 | static void select_functions(NNEDIContext *s) |
646 | { |
647 | s->copy_pad = copy_pad; |
648 | s->evalfunc_0 = evalfunc_0; |
649 | s->evalfunc_1 = evalfunc_1; |
650 | |
651 | // evalfunc_0 |
652 | s->process_line0 = process_line0; |
653 | |
654 | if (s->pscrn < 2) { // original prescreener |
655 | if (s->fapprox & 1) { // int16 dot products |
656 | s->readpixels = byte2word48; |
657 | s->compute_network0 = compute_network0_i16; |
658 | } else { |
659 | s->readpixels = pixel2float48; |
660 | s->compute_network0 = compute_network0; |
661 | } |
662 | } else { // new prescreener |
663 | // only int16 dot products |
664 | s->readpixels = byte2word64; |
665 | s->compute_network0 = compute_network0new; |
666 | } |
667 | |
668 | // evalfunc_1 |
669 | s->wae5 = weighted_avg_elliott_mul5_m16; |
670 | |
671 | if (s->fapprox & 2) { // use int16 dot products |
672 | s->extract = extract_m8_i16; |
673 | s->dot_prod = dot_prods; |
674 | } else { // use float dot products |
675 | s->extract = extract_m8; |
676 | s->dot_prod = dot_prod; |
677 | } |
678 | |
679 | s->expfunc = e2_m16; |
680 | } |
681 | |
/**
 * Return m if it is already a multiple of n, otherwise
 * m + n - (m % n), which for non-negative m is the next multiple of n.
 */
static int modnpf(const int m, const int n)
{
    const int remainder = m % n;

    return remainder ? m + n - remainder : m;
}
688 | |
689 | static int get_frame(AVFilterContext *ctx, int is_second) |
690 | { |
691 | NNEDIContext *s = ctx->priv; |
692 | AVFilterLink *outlink = ctx->outputs[0]; |
693 | AVFrame *src = s->src; |
694 | FrameData *frame_data; |
695 | int effective_field = s->field; |
696 | size_t temp_size; |
697 | int field_n; |
698 | int plane; |
699 | |
700 | if (effective_field > 1) |
701 | effective_field -= 2; |
702 | else if (effective_field < 0) |
703 | effective_field += 2; |
704 | |
705 | if (s->field < 0 && src->interlaced_frame && src->top_field_first == 0) |
706 | effective_field = 0; |
707 | else if (s->field < 0 && src->interlaced_frame && src->top_field_first == 1) |
708 | effective_field = 1; |
709 | else |
710 | effective_field = !effective_field; |
711 | |
712 | if (s->field > 1 || s->field == -2) { |
713 | if (is_second) { |
714 | field_n = (effective_field == 0); |
715 | } else { |
716 | field_n = (effective_field == 1); |
717 | } |
718 | } else { |
719 | field_n = effective_field; |
720 | } |
721 | |
722 | s->dst = ff_get_video_buffer(outlink, outlink->w, outlink->h); |
723 | if (!s->dst) |
724 | return AVERROR(ENOMEM); |
725 | av_frame_copy_props(s->dst, src); |
726 | s->dst->interlaced_frame = 0; |
727 | |
728 | frame_data = &s->frame_data; |
729 | |
730 | for (plane = 0; plane < s->nb_planes; plane++) { |
731 | int dst_height = s->planeheight[plane]; |
732 | int dst_width = s->linesize[plane]; |
733 | |
734 | const int min_alignment = 16; |
735 | const int min_pad = 10; |
736 | |
737 | if (!(s->process_plane & (1 << plane))) { |
738 | av_image_copy_plane(s->dst->data[plane], s->dst->linesize[plane], |
739 | src->data[plane], src->linesize[plane], |
740 | s->linesize[plane], |
741 | s->planeheight[plane]); |
742 | continue; |
743 | } |
744 | |
745 | frame_data->padded_width[plane] = dst_width + 64; |
746 | frame_data->padded_height[plane] = dst_height + 12; |
747 | frame_data->padded_stride[plane] = modnpf(frame_data->padded_width[plane] + min_pad, min_alignment); // TODO: maybe min_pad is in pixels too? |
748 | if (!frame_data->paddedp[plane]) { |
749 | frame_data->paddedp[plane] = av_malloc_array(frame_data->padded_stride[plane], frame_data->padded_height[plane]); |
750 | if (!frame_data->paddedp[plane]) |
751 | return AVERROR(ENOMEM); |
752 | } |
753 | |
754 | frame_data->dstp[plane] = s->dst->data[plane]; |
755 | frame_data->dst_stride[plane] = s->dst->linesize[plane]; |
756 | |
757 | if (!frame_data->lcount[plane]) { |
758 | frame_data->lcount[plane] = av_calloc(dst_height, sizeof(int32_t) * 16); |
759 | if (!frame_data->lcount[plane]) |
760 | return AVERROR(ENOMEM); |
761 | } else { |
762 | memset(frame_data->lcount[plane], 0, dst_height * sizeof(int32_t) * 16); |
763 | } |
764 | |
765 | frame_data->field[plane] = field_n; |
766 | } |
767 | |
768 | if (!frame_data->input) { |
769 | frame_data->input = av_malloc(512 * sizeof(float)); |
770 | if (!frame_data->input) |
771 | return AVERROR(ENOMEM); |
772 | } |
773 | // evalfunc_0 requires at least padded_width[0] bytes. |
774 | // evalfunc_1 requires at least 512 floats. |
775 | if (!frame_data->temp) { |
776 | temp_size = FFMAX(frame_data->padded_width[0], 512 * sizeof(float)); |
777 | frame_data->temp = av_malloc(temp_size); |
778 | if (!frame_data->temp) |
779 | return AVERROR(ENOMEM); |
780 | } |
781 | |
782 | // Copy src to a padded "frame" in frame_data and mirror the edges. |
783 | s->copy_pad(src, frame_data, s, field_n); |
784 | |
785 | // Handles prescreening and the cubic interpolation. |
786 | s->evalfunc_0(s, frame_data); |
787 | |
788 | // The rest. |
789 | s->evalfunc_1(s, frame_data); |
790 | |
791 | return 0; |
792 | } |
793 | |
794 | static int filter_frame(AVFilterLink *inlink, AVFrame *src) |
795 | { |
796 | AVFilterContext *ctx = inlink->dst; |
797 | AVFilterLink *outlink = ctx->outputs[0]; |
798 | NNEDIContext *s = ctx->priv; |
799 | int ret; |
800 | |
801 | if ((s->field > 1 || |
802 | s->field == -2) && !s->second) { |
803 | goto second; |
804 | } else if (s->field > 1 || |
805 | s->field == -2) { |
806 | AVFrame *dst; |
807 | |
808 | s->src = s->second; |
809 | ret = get_frame(ctx, 1); |
810 | if (ret < 0) { |
811 | av_frame_free(&s->dst); |
812 | av_frame_free(&s->src); |
813 | av_frame_free(&s->second); |
814 | return ret; |
815 | } |
816 | dst = s->dst; |
817 | |
818 | if (src->pts != AV_NOPTS_VALUE && |
819 | dst->pts != AV_NOPTS_VALUE) |
820 | dst->pts += src->pts; |
821 | else |
822 | dst->pts = AV_NOPTS_VALUE; |
823 | |
824 | ret = ff_filter_frame(outlink, dst); |
825 | if (ret < 0) |
826 | return ret; |
827 | if (s->eof) |
828 | return 0; |
829 | s->cur_pts = s->second->pts; |
830 | av_frame_free(&s->second); |
831 | second: |
832 | if ((s->deint && src->interlaced_frame && |
833 | !ctx->is_disabled) || |
834 | (!s->deint && !ctx->is_disabled)) { |
835 | s->second = src; |
836 | } |
837 | } |
838 | |
839 | if ((s->deint && !src->interlaced_frame) || ctx->is_disabled) { |
840 | AVFrame *dst = av_frame_clone(src); |
841 | if (!dst) { |
842 | av_frame_free(&src); |
843 | av_frame_free(&s->second); |
844 | return AVERROR(ENOMEM); |
845 | } |
846 | |
847 | if (s->field > 1 || s->field == -2) { |
848 | av_frame_free(&s->second); |
849 | if ((s->deint && src->interlaced_frame) || |
850 | (!s->deint)) |
851 | s->second = src; |
852 | } else { |
853 | av_frame_free(&src); |
854 | } |
855 | if (dst->pts != AV_NOPTS_VALUE) |
856 | dst->pts *= 2; |
857 | return ff_filter_frame(outlink, dst); |
858 | } |
859 | |
860 | s->src = src; |
861 | ret = get_frame(ctx, 0); |
862 | if (ret < 0) { |
863 | av_frame_free(&s->dst); |
864 | av_frame_free(&s->src); |
865 | av_frame_free(&s->second); |
866 | return ret; |
867 | } |
868 | |
869 | if (src->pts != AV_NOPTS_VALUE) |
870 | s->dst->pts = src->pts * 2; |
871 | if (s->field <= 1 && s->field > -2) { |
872 | av_frame_free(&src); |
873 | s->src = NULL; |
874 | } |
875 | |
876 | return ff_filter_frame(outlink, s->dst); |
877 | } |
878 | |
879 | static int request_frame(AVFilterLink *link) |
880 | { |
881 | AVFilterContext *ctx = link->src; |
882 | NNEDIContext *s = ctx->priv; |
883 | int ret; |
884 | |
885 | if (s->eof) |
886 | return AVERROR_EOF; |
887 | |
888 | ret = ff_request_frame(ctx->inputs[0]); |
889 | |
890 | if (ret == AVERROR_EOF && s->second) { |
891 | AVFrame *next = av_frame_clone(s->second); |
892 | |
893 | if (!next) |
894 | return AVERROR(ENOMEM); |
895 | |
896 | next->pts = s->second->pts * 2 - s->cur_pts; |
897 | s->eof = 1; |
898 | |
899 | filter_frame(ctx->inputs[0], next); |
900 | } else if (ret < 0) { |
901 | return ret; |
902 | } |
903 | |
904 | return 0; |
905 | } |
906 | |
907 | static av_cold int init(AVFilterContext *ctx) |
908 | { |
909 | NNEDIContext *s = ctx->priv; |
910 | FILE *weights_file = NULL; |
911 | int64_t expected_size = 13574928; |
912 | int64_t weights_size; |
913 | float *bdata; |
914 | size_t bytes_read; |
915 | const int xdia_table[NUM_NSIZE] = { 8, 16, 32, 48, 8, 16, 32 }; |
916 | const int ydia_table[NUM_NSIZE] = { 6, 6, 6, 6, 4, 4, 4 }; |
917 | const int nns_table[NUM_NNS] = { 16, 32, 64, 128, 256 }; |
918 | const int dims0 = 49 * 4 + 5 * 4 + 9 * 4; |
919 | const int dims0new = 4 * 65 + 4 * 5; |
920 | const int dims1 = nns_table[s->nnsparam] * 2 * (xdia_table[s->nsize] * ydia_table[s->nsize] + 1); |
921 | int dims1tsize = 0; |
922 | int dims1offset = 0; |
923 | int ret = 0, i, j, k; |
924 | |
925 | weights_file = fopen(s->weights_file, "rb"); |
926 | if (!weights_file) { |
927 | av_log(ctx, AV_LOG_ERROR, "No weights file provided, aborting!\n"); |
928 | return AVERROR(EINVAL); |
929 | } |
930 | |
931 | if (fseek(weights_file, 0, SEEK_END)) { |
932 | av_log(ctx, AV_LOG_ERROR, "Couldn't seek to the end of weights file.\n"); |
933 | fclose(weights_file); |
934 | return AVERROR(EINVAL); |
935 | } |
936 | |
937 | weights_size = ftell(weights_file); |
938 | |
939 | if (weights_size == -1) { |
940 | fclose(weights_file); |
941 | av_log(ctx, AV_LOG_ERROR, "Couldn't get size of weights file.\n"); |
942 | return AVERROR(EINVAL); |
943 | } else if (weights_size != expected_size) { |
944 | fclose(weights_file); |
945 | av_log(ctx, AV_LOG_ERROR, "Unexpected weights file size.\n"); |
946 | return AVERROR(EINVAL); |
947 | } |
948 | |
949 | if (fseek(weights_file, 0, SEEK_SET)) { |
950 | fclose(weights_file); |
951 | av_log(ctx, AV_LOG_ERROR, "Couldn't seek to the start of weights file.\n"); |
952 | return AVERROR(EINVAL); |
953 | } |
954 | |
955 | bdata = (float *)av_malloc(expected_size); |
956 | if (!bdata) { |
957 | fclose(weights_file); |
958 | return AVERROR(ENOMEM); |
959 | } |
960 | |
961 | bytes_read = fread(bdata, 1, expected_size, weights_file); |
962 | |
963 | if (bytes_read != (size_t)expected_size) { |
964 | fclose(weights_file); |
965 | ret = AVERROR_INVALIDDATA; |
966 | av_log(ctx, AV_LOG_ERROR, "Couldn't read weights file.\n"); |
967 | goto fail; |
968 | } |
969 | |
970 | fclose(weights_file); |
971 | |
972 | for (j = 0; j < NUM_NNS; j++) { |
973 | for (i = 0; i < NUM_NSIZE; i++) { |
974 | if (i == s->nsize && j == s->nnsparam) |
975 | dims1offset = dims1tsize; |
976 | dims1tsize += nns_table[j] * 2 * (xdia_table[i] * ydia_table[i] + 1) * 2; |
977 | } |
978 | } |
979 | |
980 | s->weights0 = av_malloc_array(FFMAX(dims0, dims0new), sizeof(float)); |
981 | if (!s->weights0) { |
982 | ret = AVERROR(ENOMEM); |
983 | goto fail; |
984 | } |
985 | |
986 | for (i = 0; i < 2; i++) { |
987 | s->weights1[i] = av_malloc_array(dims1, sizeof(float)); |
988 | if (!s->weights1[i]) { |
989 | ret = AVERROR(ENOMEM); |
990 | goto fail; |
991 | } |
992 | } |
993 | |
994 | // Adjust prescreener weights |
995 | if (s->pscrn >= 2) {// using new prescreener |
996 | const float *bdw; |
997 | int16_t *ws; |
998 | float *wf; |
999 | double mean[4] = { 0.0, 0.0, 0.0, 0.0 }; |
1000 | int *offt = av_calloc(4 * 64, sizeof(int)); |
1001 | |
1002 | if (!offt) { |
1003 | ret = AVERROR(ENOMEM); |
1004 | goto fail; |
1005 | } |
1006 | |
1007 | for (j = 0; j < 4; j++) |
1008 | for (k = 0; k < 64; k++) |
1009 | offt[j * 64 + k] = ((k >> 3) << 5) + ((j & 3) << 3) + (k & 7); |
1010 | |
1011 | bdw = bdata + dims0 + dims0new * (s->pscrn - 2); |
1012 | ws = (int16_t *)s->weights0; |
1013 | wf = (float *)&ws[4 * 64]; |
1014 | // Calculate mean weight of each first layer neuron |
1015 | for (j = 0; j < 4; j++) { |
1016 | double cmean = 0.0; |
1017 | for (k = 0; k < 64; k++) |
1018 | cmean += bdw[offt[j * 64 + k]]; |
1019 | mean[j] = cmean / 64.0; |
1020 | } |
1021 | // Factor mean removal and 1.0/127.5 scaling |
1022 | // into first layer weights. scale to int16 range |
1023 | for (j = 0; j < 4; j++) { |
1024 | double scale, mval = 0.0; |
1025 | |
1026 | for (k = 0; k < 64; k++) |
1027 | mval = FFMAX(mval, FFABS((bdw[offt[j * 64 + k]] - mean[j]) / 127.5)); |
1028 | scale = 32767.0 / mval; |
1029 | for (k = 0; k < 64; k++) |
1030 | ws[offt[j * 64 + k]] = roundds(((bdw[offt[j * 64 + k]] - mean[j]) / 127.5) * scale); |
1031 | wf[j] = (float)(mval / 32767.0); |
1032 | } |
1033 | memcpy(wf + 4, bdw + 4 * 64, (dims0new - 4 * 64) * sizeof(float)); |
1034 | av_free(offt); |
1035 | } else { // using old prescreener |
1036 | double mean[4] = { 0.0, 0.0, 0.0, 0.0 }; |
1037 | // Calculate mean weight of each first layer neuron |
1038 | for (j = 0; j < 4; j++) { |
1039 | double cmean = 0.0; |
1040 | for (k = 0; k < 48; k++) |
1041 | cmean += bdata[j * 48 + k]; |
1042 | mean[j] = cmean / 48.0; |
1043 | } |
1044 | if (s->fapprox & 1) {// use int16 dot products in first layer |
1045 | int16_t *ws = (int16_t *)s->weights0; |
1046 | float *wf = (float *)&ws[4 * 48]; |
1047 | // Factor mean removal and 1.0/127.5 scaling |
1048 | // into first layer weights. scale to int16 range |
1049 | for (j = 0; j < 4; j++) { |
1050 | double scale, mval = 0.0; |
1051 | for (k = 0; k < 48; k++) |
1052 | mval = FFMAX(mval, FFABS((bdata[j * 48 + k] - mean[j]) / 127.5)); |
1053 | scale = 32767.0 / mval; |
1054 | for (k = 0; k < 48; k++) |
1055 | ws[j * 48 + k] = roundds(((bdata[j * 48 + k] - mean[j]) / 127.5) * scale); |
1056 | wf[j] = (float)(mval / 32767.0); |
1057 | } |
1058 | memcpy(wf + 4, bdata + 4 * 48, (dims0 - 4 * 48) * sizeof(float)); |
1059 | } else {// use float dot products in first layer |
1060 | double half = (1 << 8) - 1; |
1061 | |
1062 | half /= 2; |
1063 | |
1064 | // Factor mean removal and 1.0/half scaling |
1065 | // into first layer weights. |
1066 | for (j = 0; j < 4; j++) |
1067 | for (k = 0; k < 48; k++) |
1068 | s->weights0[j * 48 + k] = (float)((bdata[j * 48 + k] - mean[j]) / half); |
1069 | memcpy(s->weights0 + 4 * 48, bdata + 4 * 48, (dims0 - 4 * 48) * sizeof(float)); |
1070 | } |
1071 | } |
1072 | |
1073 | // Adjust prediction weights |
1074 | for (i = 0; i < 2; i++) { |
1075 | const float *bdataT = bdata + dims0 + dims0new * 3 + dims1tsize * s->etype + dims1offset + i * dims1; |
1076 | const int nnst = nns_table[s->nnsparam]; |
1077 | const int asize = xdia_table[s->nsize] * ydia_table[s->nsize]; |
1078 | const int boff = nnst * 2 * asize; |
1079 | double *mean = (double *)av_calloc(asize + 1 + nnst * 2, sizeof(double)); |
1080 | |
1081 | if (!mean) { |
1082 | ret = AVERROR(ENOMEM); |
1083 | goto fail; |
1084 | } |
1085 | |
1086 | // Calculate mean weight of each neuron (ignore bias) |
1087 | for (j = 0; j < nnst * 2; j++) { |
1088 | double cmean = 0.0; |
1089 | for (k = 0; k < asize; k++) |
1090 | cmean += bdataT[j * asize + k]; |
1091 | mean[asize + 1 + j] = cmean / (double)asize; |
1092 | } |
1093 | // Calculate mean softmax neuron |
1094 | for (j = 0; j < nnst; j++) { |
1095 | for (k = 0; k < asize; k++) |
1096 | mean[k] += bdataT[j * asize + k] - mean[asize + 1 + j]; |
1097 | mean[asize] += bdataT[boff + j]; |
1098 | } |
1099 | for (j = 0; j < asize + 1; j++) |
1100 | mean[j] /= (double)(nnst); |
1101 | |
1102 | if (s->fapprox & 2) { // use int16 dot products |
1103 | int16_t *ws = (int16_t *)s->weights1[i]; |
1104 | float *wf = (float *)&ws[nnst * 2 * asize]; |
1105 | // Factor mean removal into weights, remove global offset from |
1106 | // softmax neurons, and scale weights to int16 range. |
1107 | for (j = 0; j < nnst; j++) { // softmax neurons |
1108 | double scale, mval = 0.0; |
1109 | for (k = 0; k < asize; k++) |
1110 | mval = FFMAX(mval, FFABS(bdataT[j * asize + k] - mean[asize + 1 + j] - mean[k])); |
1111 | scale = 32767.0 / mval; |
1112 | for (k = 0; k < asize; k++) |
1113 | ws[j * asize + k] = roundds((bdataT[j * asize + k] - mean[asize + 1 + j] - mean[k]) * scale); |
1114 | wf[(j >> 2) * 8 + (j & 3)] = (float)(mval / 32767.0); |
1115 | wf[(j >> 2) * 8 + (j & 3) + 4] = (float)(bdataT[boff + j] - mean[asize]); |
1116 | } |
1117 | for (j = nnst; j < nnst * 2; j++) { // elliott neurons |
1118 | double scale, mval = 0.0; |
1119 | for (k = 0; k < asize; k++) |
1120 | mval = FFMAX(mval, FFABS(bdataT[j * asize + k] - mean[asize + 1 + j])); |
1121 | scale = 32767.0 / mval; |
1122 | for (k = 0; k < asize; k++) |
1123 | ws[j * asize + k] = roundds((bdataT[j * asize + k] - mean[asize + 1 + j]) * scale); |
1124 | wf[(j >> 2) * 8 + (j & 3)] = (float)(mval / 32767.0); |
1125 | wf[(j >> 2) * 8 + (j & 3) + 4] = bdataT[boff + j]; |
1126 | } |
1127 | } else { // use float dot products |
1128 | // Factor mean removal into weights, and remove global |
1129 | // offset from softmax neurons. |
1130 | for (j = 0; j < nnst * 2; j++) { |
1131 | for (k = 0; k < asize; k++) { |
1132 | const double q = j < nnst ? mean[k] : 0.0; |
1133 | s->weights1[i][j * asize + k] = (float)(bdataT[j * asize + k] - mean[asize + 1 + j] - q); |
1134 | } |
1135 | s->weights1[i][boff + j] = (float)(bdataT[boff + j] - (j < nnst ? mean[asize] : 0.0)); |
1136 | } |
1137 | } |
1138 | av_free(mean); |
1139 | } |
1140 | |
1141 | s->nns = nns_table[s->nnsparam]; |
1142 | s->xdia = xdia_table[s->nsize]; |
1143 | s->ydia = ydia_table[s->nsize]; |
1144 | s->asize = xdia_table[s->nsize] * ydia_table[s->nsize]; |
1145 | |
1146 | s->max_value = 65535 >> 8; |
1147 | |
1148 | select_functions(s); |
1149 | |
1150 | s->fdsp = avpriv_float_dsp_alloc(0); |
1151 | if (!s->fdsp) |
1152 | ret = AVERROR(ENOMEM); |
1153 | |
1154 | fail: |
1155 | av_free(bdata); |
1156 | return ret; |
1157 | } |
1158 | |
1159 | static av_cold void uninit(AVFilterContext *ctx) |
1160 | { |
1161 | NNEDIContext *s = ctx->priv; |
1162 | int i; |
1163 | |
1164 | av_freep(&s->weights0); |
1165 | |
1166 | for (i = 0; i < 2; i++) |
1167 | av_freep(&s->weights1[i]); |
1168 | |
1169 | for (i = 0; i < s->nb_planes; i++) { |
1170 | av_freep(&s->frame_data.paddedp[i]); |
1171 | av_freep(&s->frame_data.lcount[i]); |
1172 | } |
1173 | |
1174 | av_freep(&s->frame_data.input); |
1175 | av_freep(&s->frame_data.temp); |
1176 | av_freep(&s->fdsp); |
1177 | av_frame_free(&s->second); |
1178 | } |
1179 | |
1180 | static const AVFilterPad inputs[] = { |
1181 | { |
1182 | .name = "default", |
1183 | .type = AVMEDIA_TYPE_VIDEO, |
1184 | .filter_frame = filter_frame, |
1185 | .config_props = config_input, |
1186 | }, |
1187 | { NULL } |
1188 | }; |
1189 | |
1190 | static const AVFilterPad outputs[] = { |
1191 | { |
1192 | .name = "default", |
1193 | .type = AVMEDIA_TYPE_VIDEO, |
1194 | .config_props = config_output, |
1195 | .request_frame = request_frame, |
1196 | }, |
1197 | { NULL } |
1198 | }; |
1199 | |
1200 | AVFilter ff_vf_nnedi = { |
1201 | .name = "nnedi", |
1202 | .description = NULL_IF_CONFIG_SMALL("Apply neural network edge directed interpolation intra-only deinterlacer."), |
1203 | .priv_size = sizeof(NNEDIContext), |
1204 | .priv_class = &nnedi_class, |
1205 | .init = init, |
1206 | .uninit = uninit, |
1207 | .query_formats = query_formats, |
1208 | .inputs = inputs, |
1209 | .outputs = outputs, |
1210 | .flags = AVFILTER_FLAG_SUPPORT_TIMELINE_INTERNAL, |
1211 | }; |
1212 |