Import zstd 1.1.4
[freebsd.git] / contrib / pzstd / Pzstd.cpp
1 /**
2  * Copyright (c) 2016-present, Facebook, Inc.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree. An additional grant
7  * of patent rights can be found in the PATENTS file in the same directory.
8  */
9 #include "Pzstd.h"
10 #include "SkippableFrame.h"
11 #include "utils/FileSystem.h"
12 #include "utils/Range.h"
13 #include "utils/ScopeGuard.h"
14 #include "utils/ThreadPool.h"
15 #include "utils/WorkQueue.h"
16
17 #include <chrono>
18 #include <cinttypes>
19 #include <cstddef>
20 #include <cstdio>
21 #include <memory>
22 #include <string>
23
24 #if defined(MSDOS) || defined(OS2) || defined(WIN32) || defined(_WIN32) || defined(__CYGWIN__)
25 #  include <fcntl.h>    /* _O_BINARY */
26 #  include <io.h>       /* _setmode, _isatty */
27 #  define SET_BINARY_MODE(file) { if (_setmode(_fileno(file), _O_BINARY) == -1) perror("Cannot set _O_BINARY"); }
28 #else
29 #  include <unistd.h>   /* isatty */
30 #  define SET_BINARY_MODE(file)
31 #endif
32
33 namespace pzstd {
34
35 namespace {
36 #ifdef _WIN32
37 const std::string nullOutput = "nul";
38 #else
39 const std::string nullOutput = "/dev/null";
40 #endif
41 }
42
43 using std::size_t;
44
45 static std::uintmax_t fileSizeOrZero(const std::string &file) {
46   if (file == "-") {
47     return 0;
48   }
49   std::error_code ec;
50   auto size = file_size(file, ec);
51   if (ec) {
52     size = 0;
53   }
54   return size;
55 }
56
57 static std::uint64_t handleOneInput(const Options &options,
58                              const std::string &inputFile,
59                              FILE* inputFd,
60                              const std::string &outputFile,
61                              FILE* outputFd,
62                              SharedState& state) {
63   auto inputSize = fileSizeOrZero(inputFile);
64   // WorkQueue outlives ThreadPool so in the case of error we are certain
65   // we don't accidently try to call push() on it after it is destroyed
66   WorkQueue<std::shared_ptr<BufferWorkQueue>> outs{options.numThreads + 1};
67   std::uint64_t bytesRead;
68   std::uint64_t bytesWritten;
69   {
70     // Initialize the (de)compression thread pool with numThreads
71     ThreadPool executor(options.numThreads);
72     // Run the reader thread on an extra thread
73     ThreadPool readExecutor(1);
74     if (!options.decompress) {
75       // Add a job that reads the input and starts all the compression jobs
76       readExecutor.add(
77           [&state, &outs, &executor, inputFd, inputSize, &options, &bytesRead] {
78             bytesRead = asyncCompressChunks(
79                 state,
80                 outs,
81                 executor,
82                 inputFd,
83                 inputSize,
84                 options.numThreads,
85                 options.determineParameters());
86           });
87       // Start writing
88       bytesWritten = writeFile(state, outs, outputFd, options.decompress);
89     } else {
90       // Add a job that reads the input and starts all the decompression jobs
91       readExecutor.add([&state, &outs, &executor, inputFd, &bytesRead] {
92         bytesRead = asyncDecompressFrames(state, outs, executor, inputFd);
93       });
94       // Start writing
95       bytesWritten = writeFile(state, outs, outputFd, options.decompress);
96     }
97   }
98   if (!state.errorHolder.hasError()) {
99     std::string inputFileName = inputFile == "-" ? "stdin" : inputFile;
100     std::string outputFileName = outputFile == "-" ? "stdout" : outputFile;
101     if (!options.decompress) {
102       double ratio = static_cast<double>(bytesWritten) /
103                      static_cast<double>(bytesRead + !bytesRead);
104       state.log(INFO, "%-20s :%6.2f%%   (%6" PRIu64 " => %6" PRIu64
105                    " bytes, %s)\n",
106                    inputFileName.c_str(), ratio * 100, bytesRead, bytesWritten,
107                    outputFileName.c_str());
108     } else {
109       state.log(INFO, "%-20s: %" PRIu64 " bytes \n",
110                    inputFileName.c_str(),bytesWritten);
111     }
112   }
113   return bytesWritten;
114 }
115
116 static FILE *openInputFile(const std::string &inputFile,
117                            ErrorHolder &errorHolder) {
118   if (inputFile == "-") {
119     SET_BINARY_MODE(stdin);
120     return stdin;
121   }
122   // Check if input file is a directory
123   {
124     std::error_code ec;
125     if (is_directory(inputFile, ec)) {
126       errorHolder.setError("Output file is a directory -- ignored");
127       return nullptr;
128     }
129   }
130   auto inputFd = std::fopen(inputFile.c_str(), "rb");
131   if (!errorHolder.check(inputFd != nullptr, "Failed to open input file")) {
132     return nullptr;
133   }
134   return inputFd;
135 }
136
137 static FILE *openOutputFile(const Options &options,
138                             const std::string &outputFile,
139                             SharedState& state) {
140   if (outputFile == "-") {
141     SET_BINARY_MODE(stdout);
142     return stdout;
143   }
144   // Check if the output file exists and then open it
145   if (!options.overwrite && outputFile != nullOutput) {
146     auto outputFd = std::fopen(outputFile.c_str(), "rb");
147     if (outputFd != nullptr) {
148       std::fclose(outputFd);
149       if (!state.log.logsAt(INFO)) {
150         state.errorHolder.setError("Output file exists");
151         return nullptr;
152       }
153       state.log(
154           INFO,
155           "pzstd: %s already exists; do you wish to overwrite (y/n) ? ",
156           outputFile.c_str());
157       int c = getchar();
158       if (c != 'y' && c != 'Y') {
159         state.errorHolder.setError("Not overwritten");
160         return nullptr;
161       }
162     }
163   }
164   auto outputFd = std::fopen(outputFile.c_str(), "wb");
165   if (!state.errorHolder.check(
166           outputFd != nullptr, "Failed to open output file")) {
167     return nullptr;
168   }
169   return outputFd;
170 }
171
172 int pzstdMain(const Options &options) {
173   int returnCode = 0;
174   SharedState state(options);
175   for (const auto& input : options.inputFiles) {
176     // Setup the shared state
177     auto printErrorGuard = makeScopeGuard([&] {
178       if (state.errorHolder.hasError()) {
179         returnCode = 1;
180         state.log(ERROR, "pzstd: %s: %s.\n", input.c_str(),
181                   state.errorHolder.getError().c_str());
182       }
183     });
184     // Open the input file
185     auto inputFd = openInputFile(input, state.errorHolder);
186     if (inputFd == nullptr) {
187       continue;
188     }
189     auto closeInputGuard = makeScopeGuard([&] { std::fclose(inputFd); });
190     // Open the output file
191     auto outputFile = options.getOutputFile(input);
192     if (!state.errorHolder.check(outputFile != "",
193                            "Input file does not have extension .zst")) {
194       continue;
195     }
196     auto outputFd = openOutputFile(options, outputFile, state);
197     if (outputFd == nullptr) {
198       continue;
199     }
200     auto closeOutputGuard = makeScopeGuard([&] { std::fclose(outputFd); });
201     // (de)compress the file
202     handleOneInput(options, input, inputFd, outputFile, outputFd, state);
203     if (state.errorHolder.hasError()) {
204       continue;
205     }
206     // Delete the input file if necessary
207     if (!options.keepSource) {
208       // Be sure that we are done and have written everything before we delete
209       if (!state.errorHolder.check(std::fclose(inputFd) == 0,
210                              "Failed to close input file")) {
211         continue;
212       }
213       closeInputGuard.dismiss();
214       if (!state.errorHolder.check(std::fclose(outputFd) == 0,
215                              "Failed to close output file")) {
216         continue;
217       }
218       closeOutputGuard.dismiss();
219       if (std::remove(input.c_str()) != 0) {
220         state.errorHolder.setError("Failed to remove input file");
221         continue;
222       }
223     }
224   }
225   // Returns 1 if any of the files failed to (de)compress.
226   return returnCode;
227 }
228
229 /// Construct a `ZSTD_inBuffer` that points to the data in `buffer`.
230 static ZSTD_inBuffer makeZstdInBuffer(const Buffer& buffer) {
231   return ZSTD_inBuffer{buffer.data(), buffer.size(), 0};
232 }
233
234 /**
235  * Advance `buffer` and `inBuffer` by the amount of data read, as indicated by
236  * `inBuffer.pos`.
237  */
238 void advance(Buffer& buffer, ZSTD_inBuffer& inBuffer) {
239   auto pos = inBuffer.pos;
240   inBuffer.src = static_cast<const unsigned char*>(inBuffer.src) + pos;
241   inBuffer.size -= pos;
242   inBuffer.pos = 0;
243   return buffer.advance(pos);
244 }
245
246 /// Construct a `ZSTD_outBuffer` that points to the data in `buffer`.
247 static ZSTD_outBuffer makeZstdOutBuffer(Buffer& buffer) {
248   return ZSTD_outBuffer{buffer.data(), buffer.size(), 0};
249 }
250
251 /**
252  * Split `buffer` and advance `outBuffer` by the amount of data written, as
253  * indicated by `outBuffer.pos`.
254  */
255 Buffer split(Buffer& buffer, ZSTD_outBuffer& outBuffer) {
256   auto pos = outBuffer.pos;
257   outBuffer.dst = static_cast<unsigned char*>(outBuffer.dst) + pos;
258   outBuffer.size -= pos;
259   outBuffer.pos = 0;
260   return buffer.splitAt(pos);
261 }
262
263 /**
264  * Stream chunks of input from `in`, compress it, and stream it out to `out`.
265  *
266  * @param state        The shared state
267  * @param in           Queue that we `pop()` input buffers from
268  * @param out          Queue that we `push()` compressed output buffers to
269  * @param maxInputSize An upper bound on the size of the input
270  */
271 static void compress(
272     SharedState& state,
273     std::shared_ptr<BufferWorkQueue> in,
274     std::shared_ptr<BufferWorkQueue> out,
275     size_t maxInputSize) {
276   auto& errorHolder = state.errorHolder;
277   auto guard = makeScopeGuard([&] { out->finish(); });
278   // Initialize the CCtx
279   auto ctx = state.cStreamPool->get();
280   if (!errorHolder.check(ctx != nullptr, "Failed to allocate ZSTD_CStream")) {
281     return;
282   }
283   {
284     auto err = ZSTD_resetCStream(ctx.get(), 0);
285     if (!errorHolder.check(!ZSTD_isError(err), ZSTD_getErrorName(err))) {
286       return;
287     }
288   }
289
290   // Allocate space for the result
291   auto outBuffer = Buffer(ZSTD_compressBound(maxInputSize));
292   auto zstdOutBuffer = makeZstdOutBuffer(outBuffer);
293   {
294     Buffer inBuffer;
295     // Read a buffer in from the input queue
296     while (in->pop(inBuffer) && !errorHolder.hasError()) {
297       auto zstdInBuffer = makeZstdInBuffer(inBuffer);
298       // Compress the whole buffer and send it to the output queue
299       while (!inBuffer.empty() && !errorHolder.hasError()) {
300         if (!errorHolder.check(
301                 !outBuffer.empty(), "ZSTD_compressBound() was too small")) {
302           return;
303         }
304         // Compress
305         auto err =
306             ZSTD_compressStream(ctx.get(), &zstdOutBuffer, &zstdInBuffer);
307         if (!errorHolder.check(!ZSTD_isError(err), ZSTD_getErrorName(err))) {
308           return;
309         }
310         // Split the compressed data off outBuffer and pass to the output queue
311         out->push(split(outBuffer, zstdOutBuffer));
312         // Forget about the data we already compressed
313         advance(inBuffer, zstdInBuffer);
314       }
315     }
316   }
317   // Write the epilog
318   size_t bytesLeft;
319   do {
320     if (!errorHolder.check(
321             !outBuffer.empty(), "ZSTD_compressBound() was too small")) {
322       return;
323     }
324     bytesLeft = ZSTD_endStream(ctx.get(), &zstdOutBuffer);
325     if (!errorHolder.check(
326             !ZSTD_isError(bytesLeft), ZSTD_getErrorName(bytesLeft))) {
327       return;
328     }
329     out->push(split(outBuffer, zstdOutBuffer));
330   } while (bytesLeft != 0 && !errorHolder.hasError());
331 }
332
333 /**
334  * Calculates how large each independently compressed frame should be.
335  *
336  * @param size       The size of the source if known, 0 otherwise
337  * @param numThreads The number of threads available to run compression jobs on
338  * @param params     The zstd parameters to be used for compression
339  */
340 static size_t calculateStep(
341     std::uintmax_t size,
342     size_t numThreads,
343     const ZSTD_parameters &params) {
344   (void)size;
345   (void)numThreads;
346   return size_t{1} << (params.cParams.windowLog + 2);
347 }
348
349 namespace {
350 enum class FileStatus { Continue, Done, Error };
351 /// Determines the status of the file descriptor `fd`.
352 FileStatus fileStatus(FILE* fd) {
353   if (std::feof(fd)) {
354     return FileStatus::Done;
355   } else if (std::ferror(fd)) {
356     return FileStatus::Error;
357   }
358   return FileStatus::Continue;
359 }
360 } // anonymous namespace
361
362 /**
363  * Reads `size` data in chunks of `chunkSize` and puts it into `queue`.
364  * Will read less if an error or EOF occurs.
365  * Returns the status of the file after all of the reads have occurred.
366  */
367 static FileStatus
368 readData(BufferWorkQueue& queue, size_t chunkSize, size_t size, FILE* fd,
369          std::uint64_t *totalBytesRead) {
370   Buffer buffer(size);
371   while (!buffer.empty()) {
372     auto bytesRead =
373         std::fread(buffer.data(), 1, std::min(chunkSize, buffer.size()), fd);
374     *totalBytesRead += bytesRead;
375     queue.push(buffer.splitAt(bytesRead));
376     auto status = fileStatus(fd);
377     if (status != FileStatus::Continue) {
378       return status;
379     }
380   }
381   return FileStatus::Continue;
382 }
383
384 std::uint64_t asyncCompressChunks(
385     SharedState& state,
386     WorkQueue<std::shared_ptr<BufferWorkQueue>>& chunks,
387     ThreadPool& executor,
388     FILE* fd,
389     std::uintmax_t size,
390     size_t numThreads,
391     ZSTD_parameters params) {
392   auto chunksGuard = makeScopeGuard([&] { chunks.finish(); });
393   std::uint64_t bytesRead = 0;
394
395   // Break the input up into chunks of size `step` and compress each chunk
396   // independently.
397   size_t step = calculateStep(size, numThreads, params);
398   state.log(DEBUG, "Chosen frame size: %zu\n", step);
399   auto status = FileStatus::Continue;
400   while (status == FileStatus::Continue && !state.errorHolder.hasError()) {
401     // Make a new input queue that we will put the chunk's input data into.
402     auto in = std::make_shared<BufferWorkQueue>();
403     auto inGuard = makeScopeGuard([&] { in->finish(); });
404     // Make a new output queue that compress will put the compressed data into.
405     auto out = std::make_shared<BufferWorkQueue>();
406     // Start compression in the thread pool
407     executor.add([&state, in, out, step] {
408       return compress(
409           state, std::move(in), std::move(out), step);
410     });
411     // Pass the output queue to the writer thread.
412     chunks.push(std::move(out));
413     state.log(VERBOSE, "%s\n", "Starting a new frame");
414     // Fill the input queue for the compression job we just started
415     status = readData(*in, ZSTD_CStreamInSize(), step, fd, &bytesRead);
416   }
417   state.errorHolder.check(status != FileStatus::Error, "Error reading input");
418   return bytesRead;
419 }
420
421 /**
422  * Decompress a frame, whose data is streamed into `in`, and stream the output
423  * to `out`.
424  *
425  * @param state        The shared state
426  * @param in           Queue that we `pop()` input buffers from. It contains
427  *                      exactly one compressed frame.
428  * @param out          Queue that we `push()` decompressed output buffers to
429  */
430 static void decompress(
431     SharedState& state,
432     std::shared_ptr<BufferWorkQueue> in,
433     std::shared_ptr<BufferWorkQueue> out) {
434   auto& errorHolder = state.errorHolder;
435   auto guard = makeScopeGuard([&] { out->finish(); });
436   // Initialize the DCtx
437   auto ctx = state.dStreamPool->get();
438   if (!errorHolder.check(ctx != nullptr, "Failed to allocate ZSTD_DStream")) {
439     return;
440   }
441   {
442     auto err = ZSTD_resetDStream(ctx.get());
443     if (!errorHolder.check(!ZSTD_isError(err), ZSTD_getErrorName(err))) {
444       return;
445     }
446   }
447
448   const size_t outSize = ZSTD_DStreamOutSize();
449   Buffer inBuffer;
450   size_t returnCode = 0;
451   // Read a buffer in from the input queue
452   while (in->pop(inBuffer) && !errorHolder.hasError()) {
453     auto zstdInBuffer = makeZstdInBuffer(inBuffer);
454     // Decompress the whole buffer and send it to the output queue
455     while (!inBuffer.empty() && !errorHolder.hasError()) {
456       // Allocate a buffer with at least outSize bytes.
457       Buffer outBuffer(outSize);
458       auto zstdOutBuffer = makeZstdOutBuffer(outBuffer);
459       // Decompress
460       returnCode =
461           ZSTD_decompressStream(ctx.get(), &zstdOutBuffer, &zstdInBuffer);
462       if (!errorHolder.check(
463               !ZSTD_isError(returnCode), ZSTD_getErrorName(returnCode))) {
464         return;
465       }
466       // Pass the buffer with the decompressed data to the output queue
467       out->push(split(outBuffer, zstdOutBuffer));
468       // Advance past the input we already read
469       advance(inBuffer, zstdInBuffer);
470       if (returnCode == 0) {
471         // The frame is over, prepare to (maybe) start a new frame
472         ZSTD_initDStream(ctx.get());
473       }
474     }
475   }
476   if (!errorHolder.check(returnCode <= 1, "Incomplete block")) {
477     return;
478   }
479   // We've given ZSTD_decompressStream all of our data, but there may still
480   // be data to read.
481   while (returnCode == 1) {
482     // Allocate a buffer with at least outSize bytes.
483     Buffer outBuffer(outSize);
484     auto zstdOutBuffer = makeZstdOutBuffer(outBuffer);
485     // Pass in no input.
486     ZSTD_inBuffer zstdInBuffer{nullptr, 0, 0};
487     // Decompress
488     returnCode =
489         ZSTD_decompressStream(ctx.get(), &zstdOutBuffer, &zstdInBuffer);
490     if (!errorHolder.check(
491             !ZSTD_isError(returnCode), ZSTD_getErrorName(returnCode))) {
492       return;
493     }
494     // Pass the buffer with the decompressed data to the output queue
495     out->push(split(outBuffer, zstdOutBuffer));
496   }
497 }
498
499 std::uint64_t asyncDecompressFrames(
500     SharedState& state,
501     WorkQueue<std::shared_ptr<BufferWorkQueue>>& frames,
502     ThreadPool& executor,
503     FILE* fd) {
504   auto framesGuard = makeScopeGuard([&] { frames.finish(); });
505   std::uint64_t totalBytesRead = 0;
506
507   // Split the source up into its component frames.
508   // If we find our recognized skippable frame we know the next frames size
509   // which means that we can decompress each standard frame in independently.
510   // Otherwise, we will decompress using only one decompression task.
511   const size_t chunkSize = ZSTD_DStreamInSize();
512   auto status = FileStatus::Continue;
513   while (status == FileStatus::Continue && !state.errorHolder.hasError()) {
514     // Make a new input queue that we will put the frames's bytes into.
515     auto in = std::make_shared<BufferWorkQueue>();
516     auto inGuard = makeScopeGuard([&] { in->finish(); });
517     // Make a output queue that decompress will put the decompressed data into
518     auto out = std::make_shared<BufferWorkQueue>();
519
520     size_t frameSize;
521     {
522       // Calculate the size of the next frame.
523       // frameSize is 0 if the frame info can't be decoded.
524       Buffer buffer(SkippableFrame::kSize);
525       auto bytesRead = std::fread(buffer.data(), 1, buffer.size(), fd);
526       totalBytesRead += bytesRead;
527       status = fileStatus(fd);
528       if (bytesRead == 0 && status != FileStatus::Continue) {
529         break;
530       }
531       buffer.subtract(buffer.size() - bytesRead);
532       frameSize = SkippableFrame::tryRead(buffer.range());
533       in->push(std::move(buffer));
534     }
535     if (frameSize == 0) {
536       // We hit a non SkippableFrame, so this will be the last job.
537       // Make sure that we don't use too much memory
538       in->setMaxSize(64);
539       out->setMaxSize(64);
540     }
541     // Start decompression in the thread pool
542     executor.add([&state, in, out] {
543       return decompress(state, std::move(in), std::move(out));
544     });
545     // Pass the output queue to the writer thread
546     frames.push(std::move(out));
547     if (frameSize == 0) {
548       // We hit a non SkippableFrame ==> not compressed by pzstd or corrupted
549       // Pass the rest of the source to this decompression task
550       state.log(VERBOSE, "%s\n",
551           "Input not in pzstd format, falling back to serial decompression");
552       while (status == FileStatus::Continue && !state.errorHolder.hasError()) {
553         status = readData(*in, chunkSize, chunkSize, fd, &totalBytesRead);
554       }
555       break;
556     }
557     state.log(VERBOSE, "Decompressing a frame of size %zu", frameSize);
558     // Fill the input queue for the decompression job we just started
559     status = readData(*in, chunkSize, frameSize, fd, &totalBytesRead);
560   }
561   state.errorHolder.check(status != FileStatus::Error, "Error reading input");
562   return totalBytesRead;
563 }
564
565 /// Write `data` to `fd`, returns true iff success.
566 static bool writeData(ByteRange data, FILE* fd) {
567   while (!data.empty()) {
568     data.advance(std::fwrite(data.begin(), 1, data.size(), fd));
569     if (std::ferror(fd)) {
570       return false;
571     }
572   }
573   return true;
574 }
575
576 std::uint64_t writeFile(
577     SharedState& state,
578     WorkQueue<std::shared_ptr<BufferWorkQueue>>& outs,
579     FILE* outputFd,
580     bool decompress) {
581   auto& errorHolder = state.errorHolder;
582   auto lineClearGuard = makeScopeGuard([&state] {
583     state.log.clear(INFO);
584   });
585   std::uint64_t bytesWritten = 0;
586   std::shared_ptr<BufferWorkQueue> out;
587   // Grab the output queue for each decompression job (in order).
588   while (outs.pop(out) && !errorHolder.hasError()) {
589     if (!decompress) {
590       // If we are compressing and want to write skippable frames we can't
591       // start writing before compression is done because we need to know the
592       // compressed size.
593       // Wait for the compressed size to be available and write skippable frame
594       SkippableFrame frame(out->size());
595       if (!writeData(frame.data(), outputFd)) {
596         errorHolder.setError("Failed to write output");
597         return bytesWritten;
598       }
599       bytesWritten += frame.kSize;
600     }
601     // For each chunk of the frame: Pop it from the queue and write it
602     Buffer buffer;
603     while (out->pop(buffer) && !errorHolder.hasError()) {
604       if (!writeData(buffer.range(), outputFd)) {
605         errorHolder.setError("Failed to write output");
606         return bytesWritten;
607       }
608       bytesWritten += buffer.size();
609       state.log.update(INFO, "Written: %u MB   ",
610                 static_cast<std::uint32_t>(bytesWritten >> 20));
611     }
612   }
613   return bytesWritten;
614 }
615 }