Sync rsh(1) with FreeBSD.
[dragonfly.git] / usr.bin / rsh / rsh.c
1 /*-
2  * Copyright (c) 1983, 1990, 1993, 1994
3  *      The Regents of the University of California.  All rights reserved.
4  * Copyright (c) 2002 Networks Associates Technology, Inc.
5  * All rights reserved.
6  *
7  * Portions of this software were developed for the FreeBSD Project by
8  * ThinkSec AS and NAI Labs, the Security Research Division of Network
9  * Associates, Inc.  under DARPA/SPAWAR contract N66001-01-C-8035
10  * ("CBOSS"), as part of the DARPA CHATS research program.
11  *
12  * Redistribution and use in source and binary forms, with or without
13  * modification, are permitted provided that the following conditions
14  * are met:
15  * 1. Redistributions of source code must retain the above copyright
16  *    notice, this list of conditions and the following disclaimer.
17  * 2. Redistributions in binary form must reproduce the above copyright
18  *    notice, this list of conditions and the following disclaimer in the
19  *    documentation and/or other materials provided with the distribution.
20  * 3. All advertising materials mentioning features or use of this software
21  *    must display the following acknowledgement:
22  *      This product includes software developed by the University of
23  *      California, Berkeley and its contributors.
24  * 4. Neither the name of the University nor the names of its contributors
25  *    may be used to endorse or promote products derived from this software
26  *    without specific prior written permission.
27  *
28  * THIS SOFTWARE IS PROVIDED BY THE REGENTS AND CONTRIBUTORS ``AS IS'' AND
29  * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
30  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
31  * ARE DISCLAIMED.  IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE
32  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
33  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
34  * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
35  * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
36  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
37  * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
38  * SUCH DAMAGE.
39  *
40  * @(#)rsh.c    8.3 (Berkeley) 4/6/94
41  * $FreeBSD: src/usr.bin/rsh/rsh.c,v 1.35 2005/05/21 09:55:07 ru Exp $
42  * $DragonFly: src/usr.bin/rsh/rsh.c,v 1.7 2007/05/18 17:05:12 dillon Exp $
43  */
44
45 #include <sys/param.h>
46 #include <sys/signal.h>
47 #include <sys/socket.h>
48 #include <sys/ioctl.h>
49 #include <sys/file.h>
50 #include <sys/time.h>
51
52 #include <netinet/in.h>
53 #include <netdb.h>
54
55 #include <err.h>
56 #include <errno.h>
57 #include <libutil.h>
58 #include <paths.h>
59 #include <pwd.h>
60 #include <signal.h>
61 #include <stdio.h>
62 #include <stdlib.h>
63 #include <string.h>
64 #include <unistd.h>
65
66 /*
67  * rsh - remote shell
68  */
69 int     rfd2;
70
71 int family = PF_UNSPEC;
72 char rlogin[] = "rlogin";
73
74 void    connect_timeout(int);
75 char   *copyargs(char * const *);
76 void    sendsig(int);
77 void    talk(int, long, pid_t, int, int);
78 void    usage(void);
79
80 int
81 main(int argc, char **argv)
82 {
83         struct passwd const *pw;
84         struct servent const *sp;
85         long omask;
86         int argoff, asrsh, ch, dflag, nflag, one, rem;
87         pid_t pid = 0;
88         uid_t uid;
89         char *args, *host, *p, *user;
90         int timeout = 0;
91
92         argoff = asrsh = dflag = nflag = 0;
93         one = 1;
94         host = user = NULL;
95
96         /* if called as something other than "rsh", use it as the host name */
97         if ((p = strrchr(argv[0], '/')))
98                 ++p;
99         else
100                 p = argv[0];
101         if (strcmp(p, "rsh"))
102                 host = p;
103         else
104                 asrsh = 1;
105
106         /* handle "rsh host flags" */
107         if (!host && argc > 2 && argv[1][0] != '-') {
108                 host = argv[1];
109                 argoff = 1;
110         }
111
112 #define OPTIONS "468Lde:l:nt:w"
113         while ((ch = getopt(argc - argoff, argv + argoff, OPTIONS)) != -1)
114                 switch(ch) {
115                 case '4':
116                         family = PF_INET;
117                         break;
118
119                 case '6':
120                         family = PF_INET6;
121                         break;
122
123                 case 'L':       /* -8Lew are ignored to allow rlogin aliases */
124                 case 'e':
125                 case 'w':
126                 case '8':
127                         break;
128                 case 'd':
129                         dflag = 1;
130                         break;
131                 case 'l':
132                         user = optarg;
133                         break;
134                 case 'n':
135                         nflag = 1;
136                         break;
137                 case 't':
138                         timeout = atoi(optarg);
139                         break;
140                 case '?':
141                 default:
142                         usage();
143                 }
144         optind += argoff;
145
146         /* if haven't gotten a host yet, do so */
147         if (!host && !(host = argv[optind++]))
148                 usage();
149
150         /* if no further arguments, must have been called as rlogin. */
151         if (!argv[optind]) {
152                 if (asrsh)
153                         *argv = rlogin;
154                 execv(_PATH_RLOGIN, argv);
155                 err(1, "can't exec %s", _PATH_RLOGIN);
156         }
157
158         argc -= optind;
159         argv += optind;
160
161         if (!(pw = getpwuid(uid = getuid())))
162                 errx(1, "unknown user id");
163         if (!user)
164                 user = pw->pw_name;
165
166         args = copyargs(argv);
167
168         sp = NULL;
169         if (sp == NULL)
170                 sp = getservbyname("shell", "tcp");
171         if (sp == NULL)
172                 errx(1, "shell/tcp: unknown service");
173
174         if (timeout) {
175                 signal(SIGALRM, connect_timeout);
176                 alarm(timeout);
177         }
178         rem = rcmd_af(&host, sp->s_port, pw->pw_name, user, args, &rfd2,
179                       family);
180         if (timeout) {
181                 signal(SIGALRM, SIG_DFL);
182                 alarm(0);
183         }
184
185         if (rem < 0)
186                 exit(1);
187
188         if (rfd2 < 0)
189                 errx(1, "can't establish stderr");
190         if (dflag) {
191                 if (setsockopt(rem, SOL_SOCKET, SO_DEBUG, &one,
192                     sizeof(one)) < 0)
193                         warn("setsockopt");
194                 if (setsockopt(rfd2, SOL_SOCKET, SO_DEBUG, &one,
195                     sizeof(one)) < 0)
196                         warn("setsockopt");
197         }
198
199         setuid(uid);
200         omask = sigblock(sigmask(SIGINT)|sigmask(SIGQUIT)|sigmask(SIGTERM));
201         if (signal(SIGINT, SIG_IGN) != SIG_IGN)
202                 signal(SIGINT, sendsig);
203         if (signal(SIGQUIT, SIG_IGN) != SIG_IGN)
204                 signal(SIGQUIT, sendsig);
205         if (signal(SIGTERM, SIG_IGN) != SIG_IGN)
206                 signal(SIGTERM, sendsig);
207
208         if (!nflag) {
209                 pid = fork();
210                 if (pid < 0)
211                         err(1, "fork");
212         }
213         else
214                 shutdown(rem, SHUT_WR);
215
216         ioctl(rfd2, FIONBIO, &one);
217         ioctl(rem, FIONBIO, &one);
218
219         talk(nflag, omask, pid, rem, timeout);
220
221         if (!nflag)
222                 kill(pid, SIGKILL);
223         exit(0);
224 }
225
226 void
227 talk(int nflag, long omask, pid_t pid, int rem, int timeout)
228 {
229         int cc, wc;
230         fd_set readfrom, ready, rembits;
231         char buf[BUFSIZ];
232         const char *bp;
233         struct timeval tvtimeout;
234         int nfds, srval;
235
236         if (!nflag && pid == 0) {
237                 close(rfd2);
238
239 reread:         errno = 0;
240                 if ((cc = read(STDIN_FILENO, buf, sizeof(buf))) <= 0)
241                         goto done;
242                 bp = buf;
243
244 rewrite:
245                 if (rem >= FD_SETSIZE)
246                         errx(1, "descriptor too big");
247                 FD_ZERO(&rembits);
248                 FD_SET(rem, &rembits);
249                 nfds = rem + 1;
250                 if (select(nfds, 0, &rembits, 0, 0) < 0) {
251                         if (errno != EINTR)
252                                 err(1, "select");
253                         goto rewrite;
254                 }
255                 if (!FD_ISSET(rem, &rembits))
256                         goto rewrite;
257                 wc = write(rem, bp, cc);
258                 if (wc < 0) {
259                         if (errno == EWOULDBLOCK)
260                                 goto rewrite;
261                         goto done;
262                 }
263                 bp += wc;
264                 cc -= wc;
265                 if (cc == 0)
266                         goto reread;
267                 goto rewrite;
268 done:
269                 shutdown(rem, SHUT_WR);
270                 exit(0);
271         }
272
273         tvtimeout.tv_sec = timeout;
274         tvtimeout.tv_usec = 0;
275
276         sigsetmask(omask);
277         if (rfd2 >= FD_SETSIZE || rem >= FD_SETSIZE)
278                 errx(1, "descriptor too big");
279         FD_ZERO(&readfrom);
280         FD_SET(rfd2, &readfrom);
281         FD_SET(rem, &readfrom);
282         nfds = MAX(rfd2+1, rem+1);
283         do {
284                 ready = readfrom;
285                 if (timeout) {
286                         srval = select(nfds, &ready, 0, 0, &tvtimeout);
287                 } else {
288                         srval = select(nfds, &ready, 0, 0, 0);
289                 }
290
291                 if (srval < 0) {
292                         if (errno != EINTR)
293                                 err(1, "select");
294                         continue;
295                 }
296                 if (srval == 0)
297                         errx(1, "timeout reached (%d seconds)", timeout);
298                 if (FD_ISSET(rfd2, &ready)) {
299                         errno = 0;
300                         cc = read(rfd2, buf, sizeof(buf));
301                         if (cc <= 0) {
302                                 if (errno != EWOULDBLOCK)
303                                         FD_CLR(rfd2, &readfrom);
304                         } else
305                                 write(STDERR_FILENO, buf, cc);
306                 }
307                 if (FD_ISSET(rem, &ready)) {
308                         errno = 0;
309                         cc = read(rem, buf, sizeof(buf));
310                         if (cc <= 0) {
311                                 if (errno != EWOULDBLOCK)
312                                         FD_CLR(rem, &readfrom);
313                         } else
314                                 write(STDOUT_FILENO, buf, cc);
315                 }
316         } while (FD_ISSET(rfd2, &readfrom) || FD_ISSET(rem, &readfrom));
317 }
318
319 void
320 connect_timeout(int sig __unused)
321 {
322         char message[] = "timeout reached before connection completed.\n";
323
324         write(STDERR_FILENO, message, sizeof(message) - 1);
325         _exit(1);
326 }
327
328 void
329 sendsig(int sig)
330 {
331         char signo;
332
333         signo = sig;
334         write(rfd2, &signo, 1);
335 }
336
337 char *
338 copyargs(char * const *argv)
339 {
340         int cc;
341         char *args, *p;
342         char * const *ap;
343
344         cc = 0;
345         for (ap = argv; *ap; ++ap)
346                 cc += strlen(*ap) + 1;
347         if (!(args = malloc((u_int)cc)))
348                 err(1, NULL);
349         for (p = args, ap = argv; *ap; ++ap) {
350                 strcpy(p, *ap);
351                 for (p = strcpy(p, *ap); *p; ++p);
352                 if (ap[1])
353                         *p++ = ' ';
354         }
355         return (args);
356 }
357
358 void
359 usage(void)
360 {
361
362         fprintf(stderr,
363             "usage: rsh [-46dn] [-l username] [-t timeout] host [command]\n");
364         exit(1);
365 }